
Add sampling support for Tinker SkyRL backend #999

Open

tyler-griggs wants to merge 18 commits into main from tyler/tinker-sampling-main

Conversation

tyler-griggs (Member) commented Jan 31, 2026

Summary

Implements sampling/inference support for the SkyRL Tinker backend so we can start enabling RL training workflows via the Tinker API.

Key Changes

  • Implements the sample() method in SkyRLTrainBackend using the existing InferenceEngineClient (see the sketch below)
  • Adds weight synchronization via save_weights_for_sampler(), which forwards the weights to the WorkerDispatch
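
Roughly, each request translates the Tinker SamplingParams into the dict the InferenceEngineClient expects (per the later commits, top_k/top_p pass straight through and only the stop fields are renamed) and then issues one sample() call per prompt. The helper names below are illustrative sketches, not the PR's actual code:

def to_engine_params(sampling_params) -> dict:
    # Pass most fields straight through; only the stop fields need renaming.
    params = {
        "temperature": sampling_params.temperature,
        "max_tokens": sampling_params.max_tokens,
        "seed": sampling_params.seed,
        "top_k": sampling_params.top_k,
        "top_p": sampling_params.top_p,
    }
    if sampling_params.stop_strings:
        params["stop"] = list(sampling_params.stop_strings)
    if sampling_params.stop_tokens:
        params["stop_token_ids"] = list(sampling_params.stop_tokens)
    return params

# One request (sketch): `client` stands in for the backend's InferenceEngineClient.
async def sample_one(client, prompt_token_ids, sampling_params):
    return await client.sample(
        prompt_token_ids=prompt_token_ids,
        num_samples=1,
        sampling_params=to_engine_params(sampling_params),
    )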

Limitations

  • Prompt logprobs not yet supported (returns None with warning)
  • Checkpoint paths in requests ignored (always uses current weights)
  • Single model/checkpoint at a time

tyler-griggs and others added 5 commits January 30, 2026 01:49
Add full checkpoint save/load functionality to SkyRLTrainBackend:

- save_checkpoint(): Saves model + optimizer + scheduler state as uncompressed tar
- load_checkpoint(): Restores full training state from tar checkpoint
- save_sampler_checkpoint(): Exports model weights in HuggingFace format for inference

Implementation wraps WorkerDispatch checkpoint methods and handles tar packaging.
Uses uncompressed tar to avoid 5-10 minute gzip bottleneck on 6-7GB FSDP checkpoints.
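
A minimal sketch of the tar packaging described above, assuming the helper simply archives a directory produced by the WorkerDispatch checkpoint methods (the signature is illustrative):

import tarfile

def _create_tar_from_directory(src_dir: str, tar_path: str) -> None:
    # Mode "w" (no gzip): compressing 6-7 GB FSDP checkpoints costs 5-10 minutes,
    # so checkpoints are stored as plain uncompressed tar archives.
    with tarfile.open(tar_path, "w") as tar:
        tar.add(src_dir, arcname=".")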

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Address PR review feedback:

1. Security: Add filter='data' to tarfile.extractall() to prevent path
   traversal (TarSlip) attacks where malicious archives could write
   outside the temp directory

2. Refactor: Extract duplicate validation logic into _validate_model_state()
   helper method (used by all 3 checkpoint methods)

3. Remove redundant os.path.exists() check that creates TOCTOU race
   condition - tarfile.open() already raises FileNotFoundError

4. Refactor: Extract common tar creation logic into
   _create_tar_from_directory() helper method to reduce duplication
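
A sketch of the extraction path after points 1 and 3 above (the helper name is illustrative; note the later review comment that filter="data" requires Python 3.12+):

import tarfile

def _extract_checkpoint(tar_path: str, dest_dir: str) -> None:
    # No os.path.exists() pre-check: tarfile.open() itself raises
    # FileNotFoundError, so there is no TOCTOU window.
    with tarfile.open(tar_path, "r") as tar:
        # filter="data" refuses members that would escape dest_dir (TarSlip).
        tar.extractall(dest_dir, filter="data")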

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Implements sampling/inference capabilities to enable RL training workflows
via the Tinker API, allowing rl_loop.py from tinker-cookbook to run on
the SkyRL backend.

Key changes:
- Add inference engine configuration to SkyRLTrainBackendConfig
  (num_inference_engines, tensor_parallel_size, backend, etc.)
- Initialize InferenceEngineClient in create_model() when sampling enabled
- Implement sample() method using InferenceEngineClient
- Implement save_weights_for_sampler() for in-memory weight sync
- Update engine.py to call weight sync before checkpoint save

Design:
- Two distinct operations: in-memory weight sync vs disk checkpoint save
- Weight sync is explicit (not automatic in sample())
- Supports SFT-only mode (num_inference_engines=0)
- Supports both colocated and separate GPU configurations
- Returns Tinker-format SampleOutput with tokens + logprobs (critical for RL)

Enables:
- RL training via rl_loop.py
- Cross-backend compatibility with Thinking Machines backend
- In-flight weight updates for async training

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
tyler-griggs (Member Author):

/gemini review

gemini-code-assist bot left a comment

Code Review

This pull request adds sampling and checkpointing support for the SkyRL Tinker backend, a crucial step for enabling RL training workflows. However, it introduces Path Traversal vulnerabilities in the new save_checkpoint and load_checkpoint methods, as these rely on unsanitized user-supplied identifiers (model_id, checkpoint_id), potentially allowing an attacker to read or write files outside the intended checkpoint directory. Additionally, there is a critical performance issue in the sampling implementation and a couple of medium-severity suggestions for improving configuration and maintainability.

Remove custom config fields from SkyRLTrainBackendConfig to keep it
consistent with how we handle training config (uses defaults).

Changes:
- Revert SkyRLTrainBackendConfig to empty (just pass)
- Update _create_inference_engines() to read from cfg.generator.*
- Update create_model() to check cfg.generator.num_inference_engines
- Add TODO comment about future config management

Long-term: Implement proper config mapping to allow Tinker users to
override SkyRL params via --backend_config. See TODO in
~/claude-docs/skyrl/project-summary.md for design options.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Use asyncio.gather to run all sampling tasks concurrently instead of
calling asyncio.run() in a loop for each prompt.

Before: Sequential sampling (1 event loop per prompt)
After: Parallel sampling (1 event loop, all prompts at once)

This significantly improves sampling performance when multiple prompts
are batched together.

Addresses review feedback from PR #999.
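
Schematically (`client`, `prompts`, and `params` stand in for the backend's InferenceEngineClient and the prepared batch; the call signature matches the hunks quoted later in this review):

import asyncio

# Before: a new event loop per prompt, so requests run strictly one after another.
#   outputs = [asyncio.run(client.sample(prompt_token_ids=p, num_samples=1,
#                                        sampling_params=params)) for p in prompts]

# After: a single event loop with all requests in flight concurrently.
async def sample_all(client, prompts, params):
    tasks = [
        client.sample(prompt_token_ids=p, num_samples=1, sampling_params=params)
        for p in prompts
    ]
    return await asyncio.gather(*tasks, return_exceptions=True)

# outputs = asyncio.run(sample_all(client, prompts, params))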

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
tyler-griggs (Member Author):

/gemini review

Resolved conflicts in skyrl-tx/tx/tinker/backends/skyrl_train.py:
- Kept import asyncio (needed for parallel sampling)
- Removed duplicate method definitions (_validate_model_state, _create_tar_from_directory)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
gemini-code-assist bot left a comment

Code Review

This pull request introduces significant and well-structured functionality for sampling and inference to the SkyRL backend, which is a crucial step for enabling RL workflows. The code is generally of high quality with good documentation and clear separation of concerns, particularly between in-memory weight synchronization and disk-based checkpointing. My review focuses on a critical compatibility/security issue related to tarfile, improving error handling during sampling, enhancing code readability, and adhering to Python's standard style guidelines.

I am having trouble creating individual review comments, so my feedback is included below.

skyrl-tx/tx/tinker/backends/skyrl_train.py (462)

critical

The filter argument for tarfile.extractall was introduced in Python 3.12 to mitigate path traversal vulnerabilities. If this project supports Python versions older than 3.12, this call will raise a TypeError. Please ensure the project's minimum Python version is 3.12, or provide a version-conditional fallback for older versions. For example:

import sys

if sys.version_info >= (3, 12):
    tar.extractall(temp_dir, filter="data")
else:
    # This is insecure, but maintains compatibility.
    # Consider adding your own checks to prevent path traversal.
    tar.extractall(temp_dir)

skyrl-tx/tx/tinker/backends/skyrl_train.py (340-342)

medium

Converting exceptions to None here causes the loss of valuable error information. It would be better to pass the sample_outputs list (which may contain Exception objects) directly to _aggregate_sample_results and handle the exceptions there to generate more informative error messages.

For example, in _aggregate_sample_results, you could check for exceptions and format the error message accordingly:

if isinstance(output, Exception):
    has_error = True
    error_msg = f"Sampling failed for sample {i} with error: {output!r}"
    break

skyrl-tx/tx/tinker/backends/skyrl_train.py (370)

medium

This line is a bit complex and can be hard to read. Simplifying it would improve maintainability.

                response_logprobs = (output.get("response_logprobs") or [[]])[0]

skyrl-tx/tx/tinker/engine.py (495)

medium

According to PEP 8, imports should usually be at the top of the file. Please move import asyncio to the top of the file to improve code style and consistency.

tyler-griggs (Member Author):

/gemini review

gemini-code-assist bot left a comment

Code Review

This pull request adds sampling support to the SkyRL Tinker backend, a crucial feature for enabling reinforcement learning workflows. The changes are well-structured, introducing a new sample method, helpers for creating inference engines, and an async save_weights_for_sampler method for in-memory weight updates. The implementation aligns well with the detailed design in the pull request description. I've identified one critical issue in the sampling logic that would cause a runtime error. Addressing this will ensure the new functionality works as intended.

Comment on lines 320 to 325
                if sampling_params.top_k is not None:
                    params_dict["top_k"] = sampling_params.top_k
                if sampling_params.top_p is not None:
                    params_dict["top_p"] = sampling_params.top_p
                if sampling_params.stop:
                    params_dict["stop"] = sampling_params.stop

critical

The code attempts to access sampling_params.stop, but this attribute does not exist on the types.SamplingParams model, which will cause a runtime AttributeError. The correct attributes for stop criteria are stop_strings and stop_tokens. This suggestion fixes the handling of stop sequences while preserving the correct logic for top_k and top_p.

                if sampling_params.top_k is not None:
                    params_dict["top_k"] = sampling_params.top_k
                if sampling_params.top_p is not None:
                    params_dict["top_p"] = sampling_params.top_p
                if sampling_params.stop_strings:
                    params_dict["stop"] = sampling_params.stop_strings
                if sampling_params.stop_tokens:
                    params_dict["stop_token_ids"] = sampling_params.stop_tokens

Fixes four critical issues identified in code review:

1. AttributeError on sampling_params.stop
   - Fixed: Use stop_tokens and stop_strings instead of .stop attribute
   - Now properly combines both into single stop list for vLLM

2. Stop tokens/strings not forwarded
   - Fixed: Combine stop_strings and stop_tokens into stop list
   - vLLM accepts both strings and token IDs in same list

3. Seeds not forwarded
   - Fixed: Add seed to params_dict for deterministic sampling

4. asyncio.run() in async context crashes
   - Fixed: Check for running event loop with get_running_loop()
   - Use run_coroutine_threadsafe if in async context (FastAPI)
   - Use asyncio.run() only if no running loop
   - Applied fix in both sample() and process_save_weights_for_sampler()

Additional improvements:
- Convert top_k=-1 to None (vLLM convention for no limit)
- Better error handling with return_exceptions=True

These fixes prevent crashes and enable core functionality like stop
conditions, reproducible sampling, and FastAPI compatibility.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Fix issue where exceptions from InferenceEngineClient.sample() were
converted to None, losing the original error message. This made
debugging impossible.

Changes:
- Stop converting exceptions to None in sample()
- Preserve Exception objects in sample_outputs list
- Check for isinstance(output, Exception) in aggregation
- Extract and include exception type and message in error response
- Log full error message for debugging

Before: "Sampling failed for sample X" (generic, no context)
After: "Sampling failed for sample X: RuntimeError: CUDA out of memory" (specific, actionable)

This addresses code review feedback about exception swallowing.
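
A sketch of the aggregation side of this change (the real _aggregate_sample_results has a different signature; this only shows the error-preservation idea):

def aggregate(sample_outputs):
    results, errors = [], []
    for i, output in enumerate(sample_outputs):
        if isinstance(output, Exception):
            # Keep the exception type and message so failures are actionable,
            # e.g. "Sampling failed for sample 3: RuntimeError: CUDA out of memory"
            errors.append(f"Sampling failed for sample {i}: {type(output).__name__}: {output}")
            continue
        results.append(output)
    return results, errors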

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Remove unnecessary event loop checking in engine.py since the engine
runs in a separate subprocess with no event loop.

The engine is spawned as a subprocess from FastAPI (api.py:lifespan)
and runs a sync blocking loop. There is never a running event loop in
engine.py, so asyncio.run() is always safe to use directly.

The complex try/except with get_running_loop() was unnecessary - it
would always take the RuntimeError path anyway.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- Remove redundant comment about subprocess/event loop
- Clarify that save_weights_for_sampler is SkyRL-only
- Fix log message: "sampler checkpoint" not "LoRA adapter weights"
  (SkyRL may use full fine-tuning, not just LoRA)
- Keep save_sampler_checkpoint call for all backends (needed for
  cross-service compatibility and API contract)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
gemini-code-assist bot left a comment

Code Review

This pull request successfully adds sampling support for the SkyRL Tinker backend, enabling RL training workflows. However, two significant security issues were identified: a flawed concurrency model in the sample method that leads to deadlocks in asynchronous environments, and a lack of resource constraints when processing large batches of samples, which can be exploited for Denial of Service attacks. The implementation of the sample() method using InferenceEngineClient is otherwise well-done, including robust handling of asynchronous operations and batch processing, and save_weights_for_sampler is a key feature. Minor suggestions for code style and readability have also been provided.

Comment on lines 347 to 357
        try:
            loop = asyncio.get_running_loop()
            # We're in an async context (e.g., FastAPI) - use run_coroutine_threadsafe
            import concurrent.futures

            future = asyncio.run_coroutine_threadsafe(sample_all(), loop)
            sample_outputs = future.result()
        except RuntimeError:
            # No running loop - safe to use asyncio.run()
            sample_outputs = asyncio.run(sample_all())


security-high

The sample method contains logic that will cause a permanent deadlock if called from a thread where an asyncio event loop is already running. The implementation uses asyncio.get_running_loop() and asyncio.run_coroutine_threadsafe(sample_all(), loop).result(), which blocks the event loop thread, leading to a permanent hang of the process. To remediate this, avoid blocking the event loop thread; consider making the sample method asynchronous or offloading work to a separate thread pool. Additionally, for better code organization and PEP 8 adherence, the concurrent.futures import should be moved to the top of the file.
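
One way to act on this suggestion (a sketch, not the code in this PR): keep the blocking wait off the event-loop thread, for example via asyncio.to_thread, or make the backend entry point itself async. `backend` and `prepared_batch` below are stand-ins for the actual objects:

import asyncio

async def sample_without_blocking(backend, prepared_batch):
    # backend.sample stands in for the synchronous entry point; running it in a
    # worker thread means the caller's event loop is never blocked on future.result().
    return await asyncio.to_thread(backend.sample, prepared_batch)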

Comment on lines 306 to 344
        async def sample_all():
            tasks = []
            for i in range(len(prepared_batch.all_prompts)):
                prompt = prepared_batch.all_prompts[i]
                sampling_params = prepared_batch.all_sampling_params[i]

                # Convert to InferenceEngineClient format
                params_dict = {
                    "temperature": sampling_params.temperature,
                    "max_tokens": sampling_params.max_tokens,
                    "seed": sampling_params.seed,
                }

                # Handle top_k: convert -1 to None (vLLM expects None for no limit)
                if sampling_params.top_k is not None and sampling_params.top_k != -1:
                    params_dict["top_k"] = sampling_params.top_k

                if sampling_params.top_p is not None:
                    params_dict["top_p"] = sampling_params.top_p

                # Handle stop conditions: combine stop_tokens and stop_strings
                # vLLM accepts both strings and token IDs in the same list
                stop = []
                if sampling_params.stop_strings:
                    stop.extend(sampling_params.stop_strings)
                if sampling_params.stop_tokens:
                    stop.extend(sampling_params.stop_tokens)
                if stop:
                    params_dict["stop"] = stop

                tasks.append(
                    self._inference_engine_client.sample(
                        prompt_token_ids=prompt,
                        num_samples=1,  # Tinker batches multiple samples separately
                        sampling_params=params_dict,
                    )
                )

            return await asyncio.gather(*tasks, return_exceptions=True)

security-medium

The sample method is vulnerable to resource exhaustion leading to a Denial of Service. It iterates through all prompts in prepared_batch.all_prompts (the size of which is controlled by the user-supplied num_samples parameter) and creates a task for each one, then executes all tasks in parallel using asyncio.gather. There is no upper bound on num_samples nor any chunking of the tasks. An attacker can provide an extremely large value for num_samples, causing the application to consume excessive memory and network resources, potentially leading to an Out-Of-Memory (OOM) crash or system instability.

Implement a maximum limit for num_samples and process the sampling tasks in smaller, fixed-size batches (e.g., using a semaphore or by chunking the input list).
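
For example, a semaphore-bounded variant of the fan-out (the cap value is illustrative, and `client`, `prompts`, `all_params` stand in for the backend's objects):

import asyncio

MAX_IN_FLIGHT = 64  # illustrative cap; tune to inference-engine capacity

async def sample_all_bounded(client, prompts, all_params):
    sem = asyncio.Semaphore(MAX_IN_FLIGHT)

    async def one(prompt, params):
        async with sem:  # at most MAX_IN_FLIGHT requests outstanding at once
            return await client.sample(
                prompt_token_ids=prompt, num_samples=1, sampling_params=params
            )

    tasks = [one(p, sp) for p, sp in zip(prompts, all_params)]
    return await asyncio.gather(*tasks, return_exceptions=True)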


                # Extract tokens and logprobs
                response_tokens = output["response_ids"][0]
                response_logprobs = output.get("response_logprobs", [[]])[0] if output.get("response_logprobs") else []

medium

This line is a bit dense and could be simplified for better readability. Using the or operator can make handling the default case more concise and clear.

Suggested change:
-                response_logprobs = output.get("response_logprobs", [[]])[0] if output.get("response_logprobs") else []
+                response_logprobs = (output.get("response_logprobs") or [[]])[0]

        with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.SAMPLER):
            # Sync weights to inference engines if backend supports it (SkyRL only)
            if hasattr(self.backend, "save_weights_for_sampler"):
                import asyncio

medium

To follow standard Python style guidelines (PEP 8), please move this import to the top of the file. This improves readability and makes dependencies clear.

- Remove unnecessary async context detection (backend only runs in subprocess)
- Simplify logprobs extraction using 'or' operator
- Remove unused concurrent.futures import
- Add Qwen2.5-0.5B tensor_parallel_size workaround for testing

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
cfg.trainer.policy.optimizer_config.num_warmup_steps = 0

# Workaround: Qwen2.5-0.5B has 14 attention heads, use tensor_parallel_size=2
if "Qwen2.5-0.5B" in base_model:
pcmoritz (Collaborator) commented Feb 1, 2026

Should we remove this before merging? It would probably be better to expose the .yaml so it can be set in there, right? You could add an argument to SkyRLTrainBackendConfig which specifies the path to the .yaml and then load it from here -- and then it can be used like `--backend-config '{"config_path": "/path/to/yaml"}'` or similar.
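
A sketch of that suggestion, assuming an OmegaConf-style config; the config_path field and the default-config placeholder are hypothetical, not part of this PR:

from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class SkyRLTrainBackendConfig:
    config_path: str | None = None  # hypothetical field: path to a SkyRL .yaml

def load_default_config():
    # Placeholder for however the backend builds its default SkyRL config today.
    return OmegaConf.create({})

def build_cfg(backend_config: SkyRLTrainBackendConfig):
    cfg = load_default_config()
    if backend_config.config_path:
        # Overlay the user-supplied yaml, e.g. --backend-config '{"config_path": "/path/to/yaml"}'
        cfg = OmegaConf.merge(cfg, OmegaConf.load(backend_config.config_path))
    return cfg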

tyler-griggs (Member Author):

Yeah I'll remove! I'm hoping to rely on Sumanth's changes to use pythonic configs very soon..

Reads configuration from cfg.generator.* (SkyRL's config system).
TODO: Allow overrides via backend_config once config management is implemented.
"""
from skyrl_train.inference_engines.ray_wrapped_inference_engine import create_ray_wrapped_inference_engines
Collaborator:

I realize trying to avoid importing skyrl_train will get uglier and uglier over time. We need to figure out a better way. We could turn the

BACKENDS = {

dict into a function that returns the available backends, and it will only import skyrl_train if the dependencies are installed, so everything is in one centralized place. E.g.

import warnings

def get_backends() -> dict[str, tuple]:
    backends = {}
    try:
        from tx.tinker.backends.skyrl_train import SkyRLTrainBackend, SkyRLTrainBackendConfig
        backends["skyrl_train"] = (SkyRLTrainBackend, SkyRLTrainBackendConfig)
    except ImportError:
        warnings.warn("skyrl_train backend unavailable (missing dependencies)")

    try:
        from tx.tinker.backends.jax import JaxBackend, JaxBackendConfig
        backends["jax"] = (JaxBackend, JaxBackendConfig)
    except ImportError:
        warnings.warn("jax backend unavailable (missing dependencies)")

    return backends

Maybe this can be made even nicer...


colocate_all = cfg.trainer.placement.colocate_all if hasattr(cfg.trainer.placement, "colocate_all") else False

engine_kwargs = {
Collaborator:

Could this code be structured such that cfg.generator is mostly passed through? That would make it easier going forward e.g. when new fields are added. These fields are currently implementation details that this file should better not be concerned with (the more minimal we can keep this file, the better).

"seed": sampling_params.seed,
}

# Handle top_k: convert -1 to None (vLLM expects None for no limit)
Collaborator:

It is worth checking if these are actually needed (e.g. https://github.com/vllm-project/vllm/blob/8869cd8ec1b2586df96e6556f2d2e3b3760bfb93/vllm/sampling_params.py#L159). In external_inference.py we are just passing it through and it has been working well afaik.

                if sampling_params.top_p is not None:
                    params_dict["top_p"] = sampling_params.top_p

                # Handle stop conditions: combine stop_tokens and stop_strings

tyler-griggs and others added 4 commits February 4, 2026 00:15
When the SDK uses save_weights_and_get_sampling_client() (session-based
pattern), weights only need to reach inference engines via NCCL — the
expensive HuggingFace model export to disk can be skipped entirely.

Changes:
- Add persist parameter to save_sampler_checkpoint() abstract interface
- Engine derives persist from sampling_session_seq_id presence
- SkyRL backend: always NCCL sync, conditionally write to disk
- Fix stop token handling: separate stop strings vs stop_token_ids
- Fix save_hf_model parameter name (hf_model_dir → export_dir)
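
A sketch of the control flow described above; the names are partly guessed (only save_hf_model(export_dir=...) appears elsewhere in this PR), so treat it as an illustration of the persist split, not the actual implementation:

async def save_sampler_checkpoint(backend, model_id: str, checkpoint_id: str, persist: bool) -> None:
    # Weights always reach the inference engines via NCCL weight sync...
    await backend.save_weights_for_sampler(model_id)
    # ...while the slow HuggingFace export to disk only happens when the
    # checkpoint must outlive the sampling session.
    if persist:
        backend.export_hf_checkpoint(model_id, checkpoint_id)  # hypothetical export step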

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Replace the local _create_inference_engines function (which manually
extracted ~15 fields from cfg.generator) with a call to skyrl-train's
create_ray_wrapped_inference_engines_from_config. This keeps the backend
decoupled from generator config details — new fields are automatically
forwarded without changes here.

Also simplify the model validation check per review feedback.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove hardcoded Qwen2.5-0.5B workaround from _build_config
- Lazy-import backends in engine.py to avoid requiring all backend
  dependencies at import time (only the selected backend is imported)
- Simplify sampling params: pass top_k/top_p through directly instead
  of filtering defaults; only translate stop field names
Resolve conflicts in skyrl_train.py: keep sampling implementation,
persist flag, and lazy imports while adopting main's RayPPOTrainer
architecture (self._trainer.dispatch.*).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>