11 changes: 9 additions & 2 deletions skyrl-tx/tx/tinker/backends/backend.py
@@ -133,12 +133,19 @@ def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
pass

@abstractmethod
def save_sampler_checkpoint(self, output_path, model_id: str) -> None:
"""Save sampler checkpoint to disk as tar.gz.
def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
"""Prepare model weights for sampling and optionally save to disk.

Backends that use colocated inference engines should sync weights
in-memory regardless of ``persist``. When ``persist`` is *False*
the backend may skip the expensive disk write and only place a
lightweight marker at ``output_path``.

Args:
output_path: Path to save the checkpoint tar.gz file
model_id: The model identifier
persist: If True, write a full model snapshot to disk.
If False, only sync weights in-memory (hot path).
"""
pass
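
For illustration, a condensed, hypothetical implementation of this contract (the class and helper names below are assumptions, not part of this diff); the SkyRL-Train backend further down follows the same pattern:

import os
import tarfile


class ExampleBackend:
    """Hypothetical backend sketch for the persist contract (illustration only)."""

    def _sync_weights_in_memory(self, model_id: str) -> None:
        pass  # stand-in for an in-memory (e.g. NCCL) weight broadcast

    def _export_model_as_tar(self, model_id: str, output_path) -> None:
        pass  # stand-in for a full model export plus tar archive

    def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
        # Always push the latest weights to colocated inference engines.
        self._sync_weights_in_memory(model_id)
        if persist:
            # Full snapshot written to disk at output_path.
            self._export_model_as_tar(model_id, output_path)
        else:
            # Hot path: leave only a lightweight, empty marker file on disk.
            os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
            with tarfile.open(output_path, "w"):
                pass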

179 changes: 161 additions & 18 deletions skyrl-tx/tx/tinker/backends/skyrl_train.py
@@ -4,6 +4,7 @@
Currently supports a single model only.
"""

import asyncio
import os
import tarfile
import tempfile
@@ -49,7 +50,9 @@
class SkyRLTrainBackendConfig(BaseModel, extra="forbid"):
"""Configuration for the SkyRL-Train backend.

Currently uses default config from skyrl-train.
Note: Currently uses SkyRL's default config for all parameters.
TODO: Implement proper config management to allow Tinker users to override
training and inference parameters via backend_config.
"""

pass
@@ -97,6 +100,7 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendConfig):
self._trainer: RayPPOTrainer | None = None
self._cfg = None
self._tokenizer = AutoTokenizer.from_pretrained(self.base_model)
self._inference_engine_client = None # InferenceEngineClient for sampling

def has_model(self, model_id: str) -> bool:
return self._model_id == model_id
@@ -115,9 +119,9 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
# Create placement group
colocate_pg = self._create_colocate_pg()

# Create inference engine client
# Create inference engine client (stored on self for sample())
logger.info(f"Creating {self._cfg.generator.num_inference_engines} inference engines")
inference_engine_client = InferenceEngineClient(
self._inference_engine_client = InferenceEngineClient(
create_ray_wrapped_inference_engines_from_config(self._cfg, colocate_pg, self._tokenizer),
self._tokenizer,
self._cfg,
@@ -137,7 +141,7 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
tokenizer=self._tokenizer,
train_dataset=None, # Not needed for tinker API
eval_dataset=None,
inference_engine_client=inference_engine_client,
inference_engine_client=self._inference_engine_client,
generator=None, # TODO(tyler): Update for sampling + RL
colocate_pg=colocate_pg,
)
@@ -282,7 +286,133 @@ def sample(
self,
prepared_batch: types.PreparedSampleBatch,
) -> dict[str, types.SampleOutput | types.ErrorResponse]:
raise NotImplementedError("Sampling not yet supported.")
"""Generate samples using InferenceEngineClient.

NOTE: Weight sync is NOT triggered automatically. The caller must call
save_weights_for_sampler() explicitly before calling sample() if weights
have been updated.
"""
# 1. Validate inference is enabled
if self._inference_engine_client is None:
error = types.ErrorResponse(
error="Sampling not enabled. Inference engines were not initialized (num_inference_engines=0 in SkyRL config).",
status="error",
)
return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices}

# 2. Validate single model
unique_models = set(prepared_batch.all_model_ids)
if unique_models != {self._model_id}:
error = types.ErrorResponse(
error=f"Model mismatch. Expected {self._model_id}, got {unique_models}", status="error"
)
return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices}

# 3. Sample all prompts in parallel
async def sample_all():
tasks = []
for i in range(len(prepared_batch.all_prompts)):
prompt = prepared_batch.all_prompts[i]
sampling_params = prepared_batch.all_sampling_params[i]

# Pass through common fields; only stop needs name translation
# (Tinker uses stop_strings/stop_tokens, vLLM uses stop/stop_token_ids)
params_dict = {
"temperature": sampling_params.temperature,
"max_tokens": sampling_params.max_tokens,
"seed": sampling_params.seed,
"top_k": sampling_params.top_k,
"top_p": sampling_params.top_p,
}
if sampling_params.stop_strings:
params_dict["stop"] = sampling_params.stop_strings
if sampling_params.stop_tokens:
params_dict["stop_token_ids"] = sampling_params.stop_tokens

tasks.append(
self._inference_engine_client.sample(
prompt_token_ids=prompt,
num_samples=1, # Tinker batches multiple samples separately
sampling_params=params_dict,
)
)

return await asyncio.gather(*tasks, return_exceptions=True)
Comment on lines 312 to 340
Contributor

security-medium medium

The sample method is vulnerable to resource exhaustion, leading to a denial of service. It iterates through all prompts in prepared_batch.all_prompts (whose size is controlled by the user-supplied num_samples parameter), creates a task for each one, and executes all tasks in parallel via asyncio.gather. There is no upper bound on num_samples and no chunking of the tasks. An attacker can supply an extremely large num_samples, causing the application to consume excessive memory and network resources and potentially leading to an out-of-memory (OOM) crash or system instability.

Implement a maximum limit for num_samples and process the sampling tasks in smaller, fixed-size batches (e.g., using a semaphore or by chunking the input list).
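
A minimal sketch of the suggested mitigation (not part of this PR): MAX_SAMPLES, MAX_CONCURRENCY, and the engine_sample callable are assumptions; the call signature mirrors InferenceEngineClient.sample as used above.

import asyncio

MAX_SAMPLES = 1024      # assumed cap on prompts per request
MAX_CONCURRENCY = 32    # assumed cap on in-flight engine calls


async def sample_all_bounded(prompts, sampling_params_list, engine_sample):
    """Run sampling with bounded concurrency instead of an unbounded gather."""
    if len(prompts) > MAX_SAMPLES:
        raise ValueError(f"Requested {len(prompts)} samples, limit is {MAX_SAMPLES}")

    semaphore = asyncio.Semaphore(MAX_CONCURRENCY)

    async def bounded_sample(prompt, params):
        async with semaphore:  # at most MAX_CONCURRENCY calls run at once
            return await engine_sample(prompt_token_ids=prompt, num_samples=1, sampling_params=params)

    tasks = [bounded_sample(p, sp) for p, sp in zip(prompts, sampling_params_list)]
    return await asyncio.gather(*tasks, return_exceptions=True)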


# Backend runs in engine subprocess with no event loop
sample_outputs = asyncio.run(sample_all())

# Note: sample_outputs may contain Exception objects (from return_exceptions=True)
# We preserve these to include error messages in responses

# 4. Aggregate results by request
return self._aggregate_sample_results(prepared_batch, sample_outputs)
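
As the docstring above notes, weight sync is the caller's responsibility; a hedged sketch of the intended call order (backend names from this diff, surrounding engine plumbing omitted):

# Illustrative call order only; `backend`, `output_path`, `model_id`, and
# `prepared_batch` are assumed to exist in the caller's scope.
backend.save_sampler_checkpoint(output_path, model_id, persist=False)  # in-memory weight sync
results = backend.sample(prepared_batch)  # samples now reflect the latest policy weights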

def _aggregate_sample_results(
self,
prepared_batch: types.PreparedSampleBatch,
sample_outputs: list,
) -> dict[str, types.SampleOutput | types.ErrorResponse]:
"""Convert InferenceEngineClient outputs to Tinker format."""
results = {}

for request_id, model_id, start_idx, end_idx, needs_prompt_logprobs in prepared_batch.request_batch_slices:
sequences = []
has_error = False
error_msg = None

for i in range(start_idx, end_idx):
output = sample_outputs[i]

# Check if sampling failed (Exception or None)
if isinstance(output, Exception):
has_error = True
error_msg = f"Sampling failed for sample {i}: {type(output).__name__}: {str(output)}"
logger.error(error_msg)
break
elif output is None:
has_error = True
error_msg = f"Sampling failed for sample {i}: Unknown error (output is None)"
logger.error(error_msg)
break

# Extract tokens and logprobs
response_tokens = output["response_ids"][0]
response_logprobs = (output.get("response_logprobs") or [[]])[0]
stop_reason_raw = output["stop_reasons"][0]

# Map vLLM stop reason to Tinker format
stop_reason = "stop" if stop_reason_raw in ["stop", "stop_token"] else "length"

# Ensure logprobs exist (critical for RL)
if response_logprobs is None or len(response_logprobs) == 0:
logger.warning("No logprobs returned - filling with zeros")
response_logprobs = [0.0] * len(response_tokens)

sequences.append(
types.GeneratedSequence(
tokens=response_tokens,
logprobs=response_logprobs,
stop_reason=stop_reason,
)
)

if has_error:
results[request_id] = types.ErrorResponse(
error=error_msg or "Unknown sampling error",
status="error",
)
else:
# Note: prompt_logprobs not supported initially
if needs_prompt_logprobs:
logger.warning("Prompt logprobs requested but not yet supported")

results[request_id] = types.SampleOutput(
sequences=sequences,
prompt_logprobs=None,
)

return results
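
For reference, a sketch of the per-prompt engine output shape that the aggregation above assumes (field names inferred from the accesses in this diff; values are illustrative):

# One entry of sample_outputs, as consumed above (illustrative values only).
example_output = {
    "response_ids": [[101, 2023, 2003, 102]],              # one inner list per num_samples (=1 here)
    "response_logprobs": [[-0.12, -1.40, -0.30, -0.05]],   # may be None/empty, hence the zero-fill fallback
    "stop_reasons": ["stop"],                               # anything not in {"stop", "stop_token"} maps to "length"
}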

def _validate_model_state(self, model_id: str) -> None:
"""Validate that model exists and is initialized."""
@@ -332,18 +462,31 @@ def load_checkpoint(self, checkpoint_path, model_id: str) -> None:

logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}")

def save_sampler_checkpoint(self, output_path, model_id: str) -> None:
"""Save sampler checkpoint as tar (model only, no optimizer)."""
self._validate_model_state(model_id)

# Create temp directory for HuggingFace export
with tempfile.TemporaryDirectory() as temp_dir:
hf_dir = os.path.join(temp_dir, "model")

# Save in HuggingFace format (model weights + tokenizer only)
self._trainer.dispatch.save_hf_model(model="policy", hf_model_dir=hf_dir, tokenizer=self._tokenizer)
def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
"""Sync weights to colocated inference engines and optionally save to disk.

# Create tar archive
self._create_tar_from_directory(hf_dir, output_path)
The NCCL broadcast always runs so inference engines have the latest
policy weights. When ``persist`` is False (the common hot path in RL
loops), the expensive HuggingFace model export is skipped entirely.
"""
self._validate_model_state(model_id)

logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")
# Always sync weights to inference engines (in-memory NCCL broadcast)
if self._inference_engine_client is not None:
asyncio.run(self._trainer.dispatch.save_weights_for_sampler())
logger.info(f"Synced weights for {model_id} to inference engines via NCCL")

if persist:
# Full HuggingFace model export to disk
with tempfile.TemporaryDirectory() as temp_dir:
hf_dir = os.path.join(temp_dir, "model")
self._trainer.dispatch.save_hf_model(model="policy", export_dir=hf_dir, tokenizer=self._tokenizer)
self._create_tar_from_directory(hf_dir, output_path)
logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")
else:
# Hot path: write a lightweight marker so the engine's checkpoint
# bookkeeping stays consistent. Actual weights live in GPU memory.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with tarfile.open(output_path, "w"):
pass # empty tar — marker only
logger.info(f"Synced weights for {model_id} (disk save skipped)")
9 changes: 7 additions & 2 deletions skyrl-tx/tx/tinker/engine.py
@@ -495,9 +495,14 @@ def process_save_weights_for_sampler(
checkpoint_id = Path(request_data.path).name
output_path = self.config.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz"

# When the caller provides a sampling_session_seq_id the save is
# transient — weights only need to reach the inference engines, not
# disk. Backends can skip the expensive write in that case.
persist = request_data.sampling_session_seq_id is None

with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.SAMPLER):
self.backend.save_sampler_checkpoint(output_path, model_id)
logger.info(f"Saved LoRA adapter weights for model {model_id} to {output_path}")
self.backend.save_sampler_checkpoint(output_path, model_id, persist=persist)
logger.info(f"Saved sampler checkpoint for model {model_id} to {output_path}")

# Return path=None when using sampling_session_seq_id and seq_id (SDK expects this)
if request_data.sampling_session_seq_id is not None and request_data.seq_id is not None: