Add sampling support for Tinker SkyRL backend #999
tyler-griggs wants to merge 18 commits into main from
Conversation
Add full checkpoint save/load functionality to SkyRLTrainBackend: - save_checkpoint(): Saves model + optimizer + scheduler state as uncompressed tar - load_checkpoint(): Restores full training state from tar checkpoint - save_sampler_checkpoint(): Exports model weights in HuggingFace format for inference Implementation wraps WorkerDispatch checkpoint methods and handles tar packaging. Uses uncompressed tar to avoid 5-10 minute gzip bottleneck on 6-7GB FSDP checkpoints. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
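For reference, the uncompressed-tar packaging described in this commit can be sketched as follows; the function name and arguments are illustrative, not the backend's actual helper signature:

import tarfile

def pack_checkpoint(ckpt_dir: str, output_path: str) -> None:
    """Package a checkpoint directory as an uncompressed tar archive.

    Mode "w" (no gzip) keeps packaging I/O-bound rather than CPU-bound,
    avoiding the multi-minute compression step on multi-GB FSDP shards.
    """
    with tarfile.open(output_path, mode="w") as tar:
        # arcname="." stores members relative to the checkpoint directory
        tar.add(ckpt_dir, arcname=".")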
Address PR review feedback: 1. Security: Add filter='data' to tarfile.extractall() to prevent path traversal (TarSlip) attacks where malicious archives could write outside the temp directory 2. Refactor: Extract duplicate validation logic into _validate_model_state() helper method (used by all 3 checkpoint methods) 3. Remove redundant os.path.exists() check that creates TOCTOU race condition - tarfile.open() already raises FileNotFoundError 4. Refactor: Extract common tar creation logic into _create_tar_from_directory() helper method to reduce duplication Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
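A minimal sketch of the hardened extraction path this commit describes, assuming Python 3.12+ for the filter argument (the helper name and temp-directory handling are illustrative):

import tarfile
import tempfile

def extract_checkpoint(tar_path: str) -> str:
    """Unpack a checkpoint tar into a fresh temporary directory.

    filter="data" (Python 3.12+) rejects absolute paths, ".." components,
    and special files, mitigating TarSlip-style path traversal. tarfile.open()
    raises FileNotFoundError itself, so no separate existence check is needed.
    """
    temp_dir = tempfile.mkdtemp(prefix="skyrl_ckpt_")
    with tarfile.open(tar_path) as tar:
        tar.extractall(temp_dir, filter="data")
    return temp_dir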
Implements sampling/inference capabilities to enable RL training workflows via the Tinker API, allowing rl_loop.py from tinker-cookbook to run on the SkyRL backend. Key changes: - Add inference engine configuration to SkyRLTrainBackendConfig (num_inference_engines, tensor_parallel_size, backend, etc.) - Initialize InferenceEngineClient in create_model() when sampling enabled - Implement sample() method using InferenceEngineClient - Implement save_weights_for_sampler() for in-memory weight sync - Update engine.py to call weight sync before checkpoint save Design: - Two distinct operations: in-memory weight sync vs disk checkpoint save - Weight sync is explicit (not automatic in sample()) - Supports SFT-only mode (num_inference_engines=0) - Supports both colocated and separate GPU configurations - Returns Tinker-format SampleOutput with tokens + logprobs (critical for RL) Enables: - RL training via rl_loop.py - Cross-backend compatibility with Thinking Machines backend - In-flight weight updates for async training Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
/gemini review
Code Review
This pull request adds sampling and checkpointing support for the SkyRL Tinker backend, a crucial step for enabling RL training workflows. However, it introduces Path Traversal vulnerabilities in the new save_checkpoint and load_checkpoint methods, as these rely on unsanitized user-supplied identifiers (model_id, checkpoint_id), potentially allowing an attacker to read or write files outside the intended checkpoint directory. Additionally, there is a critical performance issue in the sampling implementation and a couple of medium-severity suggestions for improving configuration and maintainability.
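One possible mitigation for the identifier-based traversal (a hedged sketch, not the PR's actual fix) is to validate the identifiers and confirm the resolved path stays under the checkpoint root; all names below are illustrative:

import re
from pathlib import Path

_ID_PATTERN = re.compile(r"[A-Za-z0-9._-]+")

def resolve_checkpoint_path(root: Path, model_id: str, checkpoint_id: str) -> Path:
    """Build a checkpoint path while rejecting traversal attempts."""
    for identifier in (model_id, checkpoint_id):
        if identifier in {".", ".."} or not _ID_PATTERN.fullmatch(identifier):
            raise ValueError(f"invalid identifier: {identifier!r}")
    path = (root / model_id / checkpoint_id).resolve()
    if root.resolve() not in path.parents:
        raise ValueError("checkpoint path escapes the checkpoint root")
    return path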
Remove custom config fields from SkyRLTrainBackendConfig to keep it consistent with how we handle training config (uses defaults). Changes: - Revert SkyRLTrainBackendConfig to empty (just pass) - Update _create_inference_engines() to read from cfg.generator.* - Update create_model() to check cfg.generator.num_inference_engines - Add TODO comment about future config management Long-term: Implement proper config mapping to allow Tinker users to override SkyRL params via --backend_config. See TODO in ~/claude-docs/skyrl/project-summary.md for design options. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
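In sketch form, the gating this commit describes would look roughly like the following; the method signature and attribute names not mentioned in the commit are assumptions:

def create_model(self, base_model: str):
    cfg = self._build_config(base_model)
    # Sampling is optional: SFT-only runs set cfg.generator.num_inference_engines = 0
    if cfg.generator.num_inference_engines > 0:
        self._inference_engine_client = self._create_inference_engines(cfg)
    else:
        self._inference_engine_client = None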
Use asyncio.gather to run all sampling tasks concurrently instead of calling asyncio.run() in a loop for each prompt. Before: Sequential sampling (1 event loop per prompt) After: Parallel sampling (1 event loop, all prompts at once) This significantly improves sampling performance when multiple prompts are batched together. Addresses review feedback from PR #999. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
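The shape of the change, as a sketch (engine and prompt names are placeholders; the sample() call signature follows the diff later in this thread):

import asyncio

# Before: one event loop per prompt, sequential round trips
# outputs = [asyncio.run(engine.sample(prompt_token_ids=p, num_samples=1,
#                                      sampling_params=params)) for p in prompts]

# After: a single event loop that runs every prompt concurrently
async def sample_all(engine, prompts, params):
    tasks = [
        engine.sample(prompt_token_ids=p, num_samples=1, sampling_params=params)
        for p in prompts
    ]
    return await asyncio.gather(*tasks, return_exceptions=True)

# outputs = asyncio.run(sample_all(engine, prompts, params))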
/gemini review
Resolved conflicts in skyrl-tx/tx/tinker/backends/skyrl_train.py: - Kept import asyncio (needed for parallel sampling) - Removed duplicate method definitions (_validate_model_state, _create_tar_from_directory) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Code Review
This pull request introduces significant and well-structured functionality for sampling and inference to the SkyRL backend, which is a crucial step for enabling RL workflows. The code is generally of high quality with good documentation and clear separation of concerns, particularly between in-memory weight synchronization and disk-based checkpointing. My review focuses on a critical compatibility/security issue related to tarfile, improving error handling during sampling, enhancing code readability, and adhering to Python's standard style guidelines.
I am having trouble creating individual review comments, so my feedback is consolidated below.
skyrl-tx/tx/tinker/backends/skyrl_train.py (462)
The filter argument for tarfile.extractall was introduced in Python 3.12 to mitigate path traversal vulnerabilities. If this project supports Python versions older than 3.12, this call will raise a TypeError. Please ensure the project's minimum Python version is 3.12, or provide a version-conditional fallback for older versions. For example:
import sys

if sys.version_info >= (3, 12):
    tar.extractall(temp_dir, filter="data")
else:
    # This is insecure, but maintains compatibility.
    # Consider adding your own checks to prevent path traversal.
    tar.extractall(temp_dir)
skyrl-tx/tx/tinker/backends/skyrl_train.py (340-342)
Converting exceptions to None here causes the loss of valuable error information. It would be better to pass the sample_outputs list (which may contain Exception objects) directly to _aggregate_sample_results and handle the exceptions there to generate more informative error messages.
For example, in _aggregate_sample_results, you could check for exceptions and format the error message accordingly:
if isinstance(output, Exception):
    has_error = True
    error_msg = f"Sampling failed for sample {i} with error: {output!r}"
    break
skyrl-tx/tx/tinker/backends/skyrl_train.py (370)
This line is a bit complex and can be hard to read. Simplifying it would improve maintainability.
response_logprobs = (output.get("response_logprobs") or [[]])[0]
skyrl-tx/tx/tinker/engine.py (495)
According to PEP 8, imports should usually be at the top of the file. Please move import asyncio to the top of the file to improve code style and consistency.
/gemini review
Code Review
This pull request adds sampling support to the SkyRL Tinker backend, a crucial feature for enabling reinforcement learning workflows. The changes are well-structured, introducing a new sample method, helpers for creating inference engines, and an async save_weights_for_sampler method for in-memory weight updates. The implementation aligns well with the detailed design in the pull request description. I've identified one critical issue in the sampling logic that would cause a runtime error. Addressing this will ensure the new functionality works as intended.
if sampling_params.top_k is not None:
    params_dict["top_k"] = sampling_params.top_k
if sampling_params.top_p is not None:
    params_dict["top_p"] = sampling_params.top_p
if sampling_params.stop:
    params_dict["stop"] = sampling_params.stop
The code attempts to access sampling_params.stop, but this attribute does not exist on the types.SamplingParams model, which will cause a runtime AttributeError. The correct attributes for stop criteria are stop_strings and stop_tokens. This suggestion fixes the handling of stop sequences while preserving the correct logic for top_k and top_p.
if sampling_params.top_k is not None:
    params_dict["top_k"] = sampling_params.top_k
if sampling_params.top_p is not None:
    params_dict["top_p"] = sampling_params.top_p
if sampling_params.stop_strings:
    params_dict["stop"] = sampling_params.stop_strings
if sampling_params.stop_tokens:
    params_dict["stop_token_ids"] = sampling_params.stop_tokens
Fixes four critical issues identified in code review: 1. AttributeError on sampling_params.stop - Fixed: Use stop_tokens and stop_strings instead of .stop attribute - Now properly combines both into single stop list for vLLM 2. Stop tokens/strings not forwarded - Fixed: Combine stop_strings and stop_tokens into stop list - vLLM accepts both strings and token IDs in same list 3. Seeds not forwarded - Fixed: Add seed to params_dict for deterministic sampling 4. asyncio.run() in async context crashes - Fixed: Check for running event loop with get_running_loop() - Use run_coroutine_threadsafe if in async context (FastAPI) - Use asyncio.run() only if no running loop - Applied fix in both sample() and process_save_weights_for_sampler() Additional improvements: - Convert top_k=-1 to None (vLLM convention for no limit) - Better error handling with return_exceptions=True These fixes prevent crashes and enable core functionality like stop conditions, reproducible sampling, and FastAPI compatibility. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Fix issue where exceptions from InferenceEngineClient.sample() were converted to None, losing the original error message. This made debugging impossible. Changes: - Stop converting exceptions to None in sample() - Preserve Exception objects in sample_outputs list - Check for isinstance(output, Exception) in aggregation - Extract and include exception type and message in error response - Log full error message for debugging Before: "Sampling failed for sample X" (generic, no context) After: "Sampling failed for sample X: RuntimeError: CUDA out of memory" (specific, actionable) This addresses code review feedback about exception swallowing. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Remove unnecessary event loop checking in engine.py since the engine runs in a separate subprocess with no event loop. The engine is spawned as a subprocess from FastAPI (api.py:lifespan) and runs a sync blocking loop. There is never a running event loop in engine.py, so asyncio.run() is always safe to use directly. The complex try/except with get_running_loop() was unnecessary - it would always take the RuntimeError path anyway. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- Remove redundant comment about subprocess/event loop - Clarify that save_weights_for_sampler is SkyRL-only - Fix log message: "sampler checkpoint" not "LoRA adapter weights" (SkyRL may use full fine-tuning, not just LoRA) - Keep save_sampler_checkpoint call for all backends (needed for cross-service compatibility and API contract) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Code Review
This pull request successfully adds sampling support for the SkyRL Tinker backend, enabling RL training workflows. However, two significant security issues were identified: a flawed concurrency model in the sample method that leads to deadlocks in asynchronous environments, and a lack of resource constraints when processing large batches of samples, which can be exploited for Denial of Service attacks. The implementation of the sample() method using InferenceEngineClient is otherwise well-done, including robust handling of asynchronous operations and batch processing, and save_weights_for_sampler is a key feature. Minor suggestions for code style and readability have also been provided.
try:
    loop = asyncio.get_running_loop()
    # We're in an async context (e.g., FastAPI) - use run_coroutine_threadsafe
    import concurrent.futures

    future = asyncio.run_coroutine_threadsafe(sample_all(), loop)
    sample_outputs = future.result()
except RuntimeError:
    # No running loop - safe to use asyncio.run()
    sample_outputs = asyncio.run(sample_all())
The sample method contains logic that will cause a permanent deadlock if called from a thread where an asyncio event loop is already running. The implementation uses asyncio.get_running_loop() and asyncio.run_coroutine_threadsafe(sample_all(), loop).result(), which blocks the event loop thread, leading to a permanent hang of the process. To remediate this, avoid blocking the event loop thread; consider making the sample method asynchronous or offloading work to a separate thread pool. Additionally, for better code organization and PEP 8 adherence, the concurrent.futures import should be moved to the top of the file.
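One way to follow the reviewer's suggestion without blocking the loop is to offload the blocking call to a worker thread, for example with asyncio.to_thread (a sketch, not the PR's implementation; backend.sample stands in for the synchronous method):

import asyncio

async def sample_without_blocking(backend, prepared_batch):
    """Run the blocking sample() off the event loop thread.

    The worker thread can then call asyncio.run() on its own private loop
    while the caller's event loop stays responsive.
    """
    return await asyncio.to_thread(backend.sample, prepared_batch)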
async def sample_all():
    tasks = []
    for i in range(len(prepared_batch.all_prompts)):
        prompt = prepared_batch.all_prompts[i]
        sampling_params = prepared_batch.all_sampling_params[i]

        # Convert to InferenceEngineClient format
        params_dict = {
            "temperature": sampling_params.temperature,
            "max_tokens": sampling_params.max_tokens,
            "seed": sampling_params.seed,
        }

        # Handle top_k: convert -1 to None (vLLM expects None for no limit)
        if sampling_params.top_k is not None and sampling_params.top_k != -1:
            params_dict["top_k"] = sampling_params.top_k

        if sampling_params.top_p is not None:
            params_dict["top_p"] = sampling_params.top_p

        # Handle stop conditions: combine stop_tokens and stop_strings
        # vLLM accepts both strings and token IDs in the same list
        stop = []
        if sampling_params.stop_strings:
            stop.extend(sampling_params.stop_strings)
        if sampling_params.stop_tokens:
            stop.extend(sampling_params.stop_tokens)
        if stop:
            params_dict["stop"] = stop

        tasks.append(
            self._inference_engine_client.sample(
                prompt_token_ids=prompt,
                num_samples=1,  # Tinker batches multiple samples separately
                sampling_params=params_dict,
            )
        )

    return await asyncio.gather(*tasks, return_exceptions=True)
The sample method is vulnerable to resource exhaustion leading to a Denial of Service. It iterates through all prompts in prepared_batch.all_prompts (the size of which is controlled by the user-supplied num_samples parameter) and creates a task for each one, then executes all tasks in parallel using asyncio.gather. There is no upper bound on num_samples nor any chunking of the tasks. An attacker can provide an extremely large value for num_samples, causing the application to consume excessive memory and network resources, potentially leading to an Out-Of-Memory (OOM) crash or system instability.
Implement a maximum limit for num_samples and process the sampling tasks in smaller, fixed-size batches (e.g., using a semaphore or by chunking the input list).
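A hedged sketch of the bounded-concurrency approach the reviewer describes, using a semaphore (the cap value and function names are illustrative, not taken from the PR):

import asyncio

MAX_CONCURRENT_SAMPLES = 64  # illustrative cap, not a value from the PR

async def sample_all_bounded(engine, prompts, params_list):
    """Issue sampling requests with bounded concurrency."""
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_SAMPLES)

    async def sample_one(prompt, params):
        async with semaphore:
            return await engine.sample(
                prompt_token_ids=prompt, num_samples=1, sampling_params=params
            )

    tasks = [sample_one(p, sp) for p, sp in zip(prompts, params_list)]
    return await asyncio.gather(*tasks, return_exceptions=True)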
# Extract tokens and logprobs
response_tokens = output["response_ids"][0]
response_logprobs = output.get("response_logprobs", [[]])[0] if output.get("response_logprobs") else []
This line is a bit dense and could be simplified for better readability. Using the or operator can make handling the default case more concise and clear.
- response_logprobs = output.get("response_logprobs", [[]])[0] if output.get("response_logprobs") else []
+ response_logprobs = (output.get("response_logprobs") or [[]])[0]
skyrl-tx/tx/tinker/engine.py (Outdated)
with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.SAMPLER):
    # Sync weights to inference engines if backend supports it (SkyRL only)
    if hasattr(self.backend, "save_weights_for_sampler"):
        import asyncio
- Remove unnecessary async context detection (backend only runs in subprocess) - Simplify logprobs extraction using 'or' operator - Remove unused concurrent.futures import - Add Qwen2.5-0.5B tensor_parallel_size workaround for testing Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Force-pushed from 66efcc0 to e080d59
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
cfg.trainer.policy.optimizer_config.num_warmup_steps = 0

# Workaround: Qwen2.5-0.5B has 14 attention heads, use tensor_parallel_size=2
if "Qwen2.5-0.5B" in base_model:
Should we remove this before merging? It would probably be better to expose the .yaml so it can be set in there, right? You could add an argument to SkyRLTrainBackendConfig which specifies the path to the .yaml and then load it from here -- and then it can be used like `--backend-config '{"config_path": "/path/to/yaml"}'` or similar.
Yeah, I'll remove it! I'm hoping to rely on Sumanth's changes to use pythonic configs very soon.
    Reads configuration from cfg.generator.* (SkyRL's config system).
    TODO: Allow overrides via backend_config once config management is implemented.
    """
    from skyrl_train.inference_engines.ray_wrapped_inference_engine import create_ray_wrapped_inference_engines
I realize trying to avoid importing skyrl_train will get uglier and uglier over time. We need to figure out a better way; we could make this a function:
SkyRL/skyrl-tx/tx/tinker/engine.py, line 133 in afa60b7
def get_backends() -> dict[str, tuple[type, type]]:
    backends = {}
    try:
        from tx.tinker.backends.skyrl_train import SkyRLTrainBackend, SkyRLTrainBackendConfig
        backends["skyrl_train"] = (SkyRLTrainBackend, SkyRLTrainBackendConfig)
    except ImportError:
        pass  # print warning
    try:
        from tx.tinker.backends.jax import JaxBackend, JaxBackendConfig
        backends["jax"] = (JaxBackend, JaxBackendConfig)
    except ImportError:
        pass  # print warning
    return backends
Maybe this can be made even nicer...
colocate_all = cfg.trainer.placement.colocate_all if hasattr(cfg.trainer.placement, "colocate_all") else False

engine_kwargs = {
Could this code be structured such that cfg.generator is mostly passed through? That would make it easier going forward e.g. when new fields are added. These fields are currently implementation details that this file should better not be concerned with (the more minimal we can keep this file, the better).
| "seed": sampling_params.seed, | ||
| } | ||
|
|
||
| # Handle top_k: convert -1 to None (vLLM expects None for no limit) |
It is worth checking if these are actually needed (e.g. https://github.com/vllm-project/vllm/blob/8869cd8ec1b2586df96e6556f2d2e3b3760bfb93/vllm/sampling_params.py#L159). In external_inference.py we are just passing it through and it has been working well afaik.
if sampling_params.top_p is not None:
    params_dict["top_p"] = sampling_params.top_p

# Handle stop conditions: combine stop_tokens and stop_strings
Do these need to be handled differently? E.g. stop vs stop_token_ids https://github.com/vllm-project/vllm/blob/8869cd8ec1b2586df96e6556f2d2e3b3760bfb93/vllm/sampling_params.py#L171
When the SDK uses save_weights_and_get_sampling_client() (session-based pattern), weights only need to reach inference engines via NCCL — the expensive HuggingFace model export to disk can be skipped entirely. Changes: - Add persist parameter to save_sampler_checkpoint() abstract interface - Engine derives persist from sampling_session_seq_id presence - SkyRL backend: always NCCL sync, conditionally write to disk - Fix stop token handling: separate stop strings vs stop_token_ids - Fix save_hf_model parameter name (hf_model_dir → export_dir) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
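Roughly, the persist flag splits the two costs like this (a sketch under assumptions: the weight-sync helper and directory-layout names are illustrative, while save_hf_model(export_dir=...) and the dispatch path come from this PR):

async def save_sampler_checkpoint(self, model_id: str, checkpoint_id: str, persist: bool = True):
    """Always sync weights to the inference engines; only export to disk when persisting."""
    # In-memory NCCL broadcast: cheap, always needed for sampling
    await self._sync_weights_to_inference_engines()  # illustrative helper name
    if persist:
        # Expensive HuggingFace export, skipped for session-based sampling clients
        export_dir = self._sampler_checkpoint_dir(model_id, checkpoint_id)  # illustrative
        self._trainer.dispatch.save_hf_model(export_dir=export_dir)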
Replace the local _create_inference_engines function (which manually extracted ~15 fields from cfg.generator) with a call to skyrl-train's create_ray_wrapped_inference_engines_from_config. This keeps the backend decoupled from generator config details — new fields are automatically forwarded without changes here. Also simplify the model validation check per review feedback. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
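The resulting call site is roughly the following; the factory name comes from the commit above, but its module path and argument list are assumptions for illustration:

def _create_inference_engines(self, cfg, tokenizer):
    # Forward the whole config so new generator fields flow through untouched.
    # Exact import location and signature are assumed here.
    from skyrl_train.inference_engines.ray_wrapped_inference_engine import (
        create_ray_wrapped_inference_engines_from_config,
    )
    return create_ray_wrapped_inference_engines_from_config(cfg, tokenizer=tokenizer)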
- Remove hardcoded Qwen2.5-0.5B workaround from _build_config - Lazy-import backends in engine.py to avoid requiring all backend dependencies at import time (only the selected backend is imported) - Simplify sampling params: pass top_k/top_p through directly instead of filtering defaults; only translate stop field names
Resolve conflicts in skyrl_train.py: keep sampling implementation, persist flag, and lazy imports while adopting main's RayPPOTrainer architecture (self._trainer.dispatch.*). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Summary
Implements sampling/inference support for the SkyRL Tinker backend so we can start enabling RL training workflows via the Tinker API.
Key Changes
- sample() method in SkyRLTrainBackend using the existing InferenceEngineClient
- save_weights_for_sampler(), which just passes it along to the WorkerDispatch
Limitations
- None with warning)