From 29dd0add12bb40b5d82138500b066a51877519cb Mon Sep 17 00:00:00 2001
From: Tyler Griggs
Date: Thu, 29 Jan 2026 02:09:58 +0000
Subject: [PATCH 1/2] Implement checkpointing for Tinker SkyRL backend

Add support for saving and loading training checkpoints, enabling
sl_loop.py to run unchanged with full checkpoint/resume functionality.

Key changes:
- save_checkpoint(): Saves full training state (model + optimizer +
  scheduler) as tar.gz archive
- load_checkpoint(): Loads training state from tar.gz and restores
  optimizer and scheduler states for seamless resume
- save_sampler_checkpoint(): Exports HuggingFace format for
  inference/sampling (model only, no optimizer)

Implementation leverages existing WorkerDispatch checkpoint methods:
- Uses FSDP's distributed checkpoint format (per-rank sharded files)
- Automatically includes LoRA adapter state
- Preserves RNG state for reproducibility

This enables:
- Periodic checkpoint saves during training
- Resume training from last checkpoint
- Optimizer state preservation (no loss spikes on resume)

Co-Authored-By: Claude Sonnet 4.5
---
 .../tinker/backends/skyrl_train.py           | 72 ++++++++++++++++++-
 1 file changed, 69 insertions(+), 3 deletions(-)

diff --git a/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py b/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py
index 3b32eeaf9..a2eba7862 100644
--- a/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py
+++ b/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py
@@ -6,6 +6,9 @@
 
 print("[DEBUG] skyrl_train.py: Starting imports...", flush=True)
 
+import os
+import tarfile
+import tempfile
 from typing import Any
 
 import torch
@@ -246,10 +249,73 @@ def sample(
         raise NotImplementedError("Sampling not supported")
 
     def save_checkpoint(self, output_path, model_id: str) -> None:
-        raise NotImplementedError("Saving checkpoints not supported")
+        """Save full training checkpoint (model + optimizer + scheduler) as tar.gz."""
+        if model_id != self._model_id:
+            raise ValueError(f"Model {model_id} not found")
+        if self._dispatch is None:
+            raise RuntimeError("Model not initialized")
+
+        # Create temp directory for checkpoint
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ckpt_dir = os.path.join(temp_dir, "checkpoint")
+
+            # Save checkpoint directory (includes optimizer state automatically)
+            self._dispatch.save_checkpoint(
+                model="policy",
+                ckpt_dir=ckpt_dir,
+                tokenizer=self._tokenizer
+            )
+
+            # Create tar.gz archive
+            with tarfile.open(output_path, "w:gz") as tar:
+                tar.add(ckpt_dir, arcname=".")
+
+        logger.info(f"Saved checkpoint for {model_id} to {output_path}")
 
     def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
-        raise NotImplementedError("Loading checkpoints not supported")
+        """Load full training checkpoint (model + optimizer + scheduler) from tar.gz."""
+        if model_id != self._model_id:
+            raise ValueError(f"Model {model_id} not found")
+        if self._dispatch is None:
+            raise RuntimeError("Model not initialized")
+        if not os.path.exists(checkpoint_path):
+            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+        # Extract tar.gz to temp directory
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with tarfile.open(checkpoint_path, "r:gz") as tar:
+                tar.extractall(temp_dir)
+
+            # Load checkpoint (includes optimizer and scheduler states)
+            self._dispatch.load_checkpoint(
+                model="policy",
+                ckpt_dir=temp_dir,
+                load_optimizer_states=True,
+                load_lr_scheduler_states=True
+            )
+
+        logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}")
 
     def save_sampler_checkpoint(self, output_path, model_id: str) -> None:
-        raise NotImplementedError("Sampler checkpoints not supported")
+        """Save sampler checkpoint as tar.gz (model only, no optimizer)."""
+        if model_id != self._model_id:
+            raise ValueError(f"Model {model_id} not found")
+        if self._dispatch is None:
+            raise RuntimeError("Model not initialized")
+
+        # Create temp directory for HuggingFace export
+        with tempfile.TemporaryDirectory() as temp_dir:
+            hf_dir = os.path.join(temp_dir, "model")
+
+            # Save in HuggingFace format (model weights + tokenizer only)
+            self._dispatch.save_hf_model(
+                model="policy",
+                hf_model_dir=hf_dir,
+                tokenizer=self._tokenizer
+            )
+
+            # Create tar.gz archive
+            with tarfile.open(output_path, "w:gz") as tar:
+                tar.add(hf_dir, arcname=".")
+
+        logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")

From 181518b7e99a3388972c88db8b839ab0dc35294d Mon Sep 17 00:00:00 2001
From: Tyler Griggs
Date: Thu, 29 Jan 2026 17:56:12 +0000
Subject: [PATCH 2/2] Perf: Use uncompressed tar for checkpoints

Checkpoints are already large (6-7GB with FSDP sharding), and gzip
compression adds 5-10 minutes of single-threaded CPU time that blocks
training. Uncompressed tar is much faster.

Future optimization: move checkpoint saving to async background thread.

Co-Authored-By: Claude Sonnet 4.5
---
 .../skyrl_train/tinker/backends/skyrl_train.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py b/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py
index a2eba7862..c99a4339d 100644
--- a/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py
+++ b/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py
@@ -266,8 +266,10 @@ def save_checkpoint(self, output_path, model_id: str) -> None:
                 tokenizer=self._tokenizer
             )
 
-            # Create tar.gz archive
-            with tarfile.open(output_path, "w:gz") as tar:
+            # Create tar archive (uncompressed for speed)
+            # FSDP checkpoints are already large (6-7GB). Gzip compression adds
+            # 5-10 minutes of single-threaded CPU time that blocks training.
+            with tarfile.open(output_path, "w") as tar:
                 tar.add(ckpt_dir, arcname=".")
 
         logger.info(f"Saved checkpoint for {model_id} to {output_path}")
@@ -281,9 +283,9 @@ def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
         if not os.path.exists(checkpoint_path):
             raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
 
-        # Extract tar.gz to temp directory
+        # Extract tar to temp directory (auto-detects compression)
         with tempfile.TemporaryDirectory() as temp_dir:
-            with tarfile.open(checkpoint_path, "r:gz") as tar:
+            with tarfile.open(checkpoint_path, "r") as tar:
                 tar.extractall(temp_dir)
 
             # Load checkpoint (includes optimizer and scheduler states)
@@ -314,8 +316,8 @@ def save_sampler_checkpoint(self, output_path, model_id: str) -> None:
                 tokenizer=self._tokenizer
             )
 
-            # Create tar.gz archive
-            with tarfile.open(output_path, "w:gz") as tar:
+            # Create tar archive (uncompressed for speed)
+            with tarfile.open(output_path, "w") as tar:
                 tar.add(hf_dir, arcname=".")
 
         logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")
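
Note for reviewers: a minimal sketch of how a training loop might drive the new
checkpoint methods (periodic saves plus resume-on-restart). Only the three method
names and their signatures come from the patches above; the backend object, the
train_step callable, the model_id value, and the checkpoint paths are illustrative
assumptions, not part of this change.

    import os

    # Assumed locations and identifier, for illustration only.
    CKPT_PATH = "/tmp/ckpts/latest.tar"      # full training state (model + optimizer + scheduler)
    SAMPLER_PATH = "/tmp/ckpts/sampler.tar"  # HuggingFace-format weights for inference/sampling
    MODEL_ID = "my-model"                    # whatever model_id the backend was initialized with

    def run(backend, train_step, num_steps=1000, save_every=100):
        # Resume from the last full checkpoint if one exists; optimizer and
        # scheduler states are restored, so training continues without loss spikes.
        if os.path.exists(CKPT_PATH):
            backend.load_checkpoint(CKPT_PATH, MODEL_ID)

        for step in range(num_steps):
            train_step()  # placeholder for the actual forward/backward/optimizer step

            if (step + 1) % save_every == 0:
                # Full state for resuming training later.
                backend.save_checkpoint(CKPT_PATH, MODEL_ID)
                # Model-only HuggingFace export for the sampler / inference engine.
                backend.save_sampler_checkpoint(SAMPLER_PATH, MODEL_ID)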