Merged
22 commits
e251093
Implement checkpointing for Tinker SkyRL backend
tyler-griggs Jan 30, 2026
4125427
Format: Consolidate function call arguments to single line
tyler-griggs Jan 30, 2026
84feed7
Shorten checkpoint compression comment
tyler-griggs Jan 30, 2026
92bb84f
Security and refactoring improvements for checkpointing
tyler-griggs Jan 30, 2026
6e88b30
Add sampling support for Tinker SkyRL backend
tyler-griggs Jan 31, 2026
0d27131
Use SkyRL default config for inference engines (consistency fix)
tyler-griggs Jan 31, 2026
cf90f35
Fix sequential sampling - run all samples in parallel
tyler-griggs Jan 31, 2026
b4d26a6
Merge origin/main into tyler/tinker-sampling-main
tyler-griggs Jan 31, 2026
4338348
Fix critical bugs in sampling implementation
tyler-griggs Jan 31, 2026
f763f21
Preserve exception messages in sampling error responses
tyler-griggs Jan 31, 2026
707a215
Simplify async handling in engine.py
tyler-griggs Jan 31, 2026
1e0ac27
Clean up engine.py comments and log messages
tyler-griggs Jan 31, 2026
e080d59
Simplify sampling implementation
tyler-griggs Feb 1, 2026
26088c4
Remove redundant vLLM comment
tyler-griggs Feb 1, 2026
27942d9
Add persist flag to skip disk save on RL loop hot path
tyler-griggs Feb 4, 2026
2af1854
Delegate inference engine creation to skyrl-train, simplify model check
tyler-griggs Feb 4, 2026
900d317
Address remaining review feedback
tyler-griggs Feb 4, 2026
746680e
Merge origin/main into tyler/tinker-sampling-main
tyler-griggs Feb 4, 2026
c84a7a4
Add importance_sampling loss to PolicyLossRegistry for Tinker API
Feb 5, 2026
e867686
Fix: Initialize weight sync state in Tinker SkyRL backend
Feb 5, 2026
fc0c467
Clean up comments in skyrl_train.py and add project docs
Feb 6, 2026
b071883
Add quickstart guide for running tinker-cookbook on SkyRL
Feb 6, 2026
136 changes: 136 additions & 0 deletions claude/project-summary.md
@@ -0,0 +1,136 @@
# SkyRL Tinker Integration - Project Summary

**Last Updated:** 2026-02-06
**Branch:** `tyler/tinker-sampling-main` (PR #999)
**Status:** Ready for Merge

---

## Completed Work

### PR #999: Tinker SkyRL Backend Sampling Support

**Key commits:**
1. Initial sampling implementation with logprobs support
2. `save_weights_for_sampler()` with ephemeral mode (persist=False)
3. Checkpoint save/load functionality
4. `importance_sampling` loss added to PolicyLossRegistry
5. `init_weight_sync_state()` fix for Tinker API flow

### Verified Functionality

| Feature | Status | Notes |
|---------|--------|-------|
| Tinker API server startup | Done | With SkyRL backend on 4xL4 GPUs |
| Model creation (LoRA) | Done | `create_lora_training_client()` |
| Sampling with logprobs | Done | Response logprobs returned correctly |
| Weight sync to inference | Done | `save_weights_for_sampler()` works |
| `forward_backward()` | Done | With importance_sampling loss |
| `optim_step()` | Done | Learning rate applied |
| Checkpoint save | Done | `tinker://model_id/weights/N` format |
| rl_loop.py end-to-end | Done | 9 batches completed successfully |

### Test Results (2026-02-06)

```
rl_loop.py with Qwen/Qwen3-0.6B:
- Batches completed: 9 (stopped due to disk space, not code error)
- Checkpoint saved: tinker://model_5987ddb1/weights/000005
- Metrics logged: /tmp/tinker-rl-test/metrics.jsonl
- Average batch time: ~18 seconds
```
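
The checkpoint path in the results above follows the `tinker://model_id/weights/N` format. A minimal sketch of splitting such a path into its parts, assuming the last segment is always a zero-padded integer; `CheckpointRef` and `parse_checkpoint_path` are hypothetical helper names, not part of Tinker or SkyRL:

```python
from dataclasses import dataclass
from urllib.parse import urlsplit


@dataclass(frozen=True)
class CheckpointRef:
    """Hypothetical container for the parts of a tinker:// checkpoint path."""

    model_id: str
    kind: str
    step: int


def parse_checkpoint_path(path: str) -> CheckpointRef:
    """Split tinker://<model_id>/<kind>/<step> into its three parts."""
    parts = urlsplit(path)
    if parts.scheme != "tinker" or not parts.netloc:
        raise ValueError(f"not a tinker:// path: {path!r}")
    segments = [s for s in parts.path.split("/") if s]
    # Assumption: exactly two segments follow the model id, the second numeric.
    if len(segments) != 2 or not segments[1].isdigit():
        raise ValueError(f"expected <kind>/<step> after the model id: {path!r}")
    return CheckpointRef(model_id=parts.netloc, kind=segments[0], step=int(segments[1]))


ref = parse_checkpoint_path("tinker://model_5987ddb1/weights/000005")
print(ref.model_id, ref.kind, ref.step)
```

Parsing the step as an integer makes it easy to sort checkpoints or pick the latest one without relying on the zero padding.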

---

## Key Files Modified

1. **skyrl-train/skyrl_train/utils/ppo_utils.py**
   - Added `IMPORTANCE_SAMPLING` to `PolicyLossType` enum
   - Implemented `importance_sampling_loss()` function
   - Registered in `PolicyLossRegistry.repopulate_registry()`

2. **skyrl-tx/tx/tinker/backends/skyrl_train.py**
   - Added `init_weight_sync_state()` call after `build_models()`
   - This initializes `_weight_transfer_sender`, which is required for weight sync

---

## Architecture Notes

### Tinker API Flow (SkyRL Backend)
```
ServiceClient.create_lora_training_client()
  -> SkyRLTrainBackend.create_model()
       -> RayPPOTrainer(...)
       -> trainer.build_models(PolicyWorker, ...)
       -> trainer.init_weight_sync_state()  <- CRITICAL: must be called!

TrainingClient.save_weights_for_sampler()
  -> backend.save_weights_for_sampler(persist=False)
       -> dispatch.broadcast_to_inference_engines()
            -> worker._weight_transfer_sender.send_chunks()
```
```
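
The ordering constraint in the flow above can be illustrated with a toy model. This is a sketch only: `ToyWorker` is a hypothetical stand-in, not SkyRL code, and it borrows the names `init_weight_sync_state`, `broadcast_to_inference_engines` and `_weight_transfer_sender` purely to show why the sync state has to exist before the first weight broadcast.

```python
class ToyWorker:
    """Hypothetical stand-in for a policy worker; NOT the SkyRL implementation."""

    def __init__(self) -> None:
        # Unset until init_weight_sync_state() runs, as in the flow above.
        self._weight_transfer_sender = None

    def init_weight_sync_state(self) -> None:
        # The real backend sets up a sender object; a plain list stands in here.
        self._weight_transfer_sender = []

    def broadcast_to_inference_engines(self, chunks: list[bytes]) -> int:
        if self._weight_transfer_sender is None:
            raise RuntimeError("init_weight_sync_state() must run before weight sync")
        self._weight_transfer_sender.extend(chunks)
        return len(self._weight_transfer_sender)


worker = ToyWorker()
try:
    worker.broadcast_to_inference_engines([b"chunk-0"])
except RuntimeError as err:
    print("without init:", err)

worker.init_weight_sync_state()
print("chunks sent:", worker.broadcast_to_inference_engines([b"chunk-0", b"chunk-1"]))
```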

### Loss Function Implementation
```python
# importance_sampling matches Tinker docs:
# https://tinker-docs.thinkingmachines.ai/losses#policy-gradient-importance_sampling
prob_ratio = torch.exp(log_probs - old_log_probs)
loss = -(prob_ratio * advantages).sum()
```
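
As a quick numeric check of the two lines above, the same formula can be evaluated in plain Python without `torch`. This is an illustrative sketch, not the function added to `ppo_utils.py`; the input values are made up for the example.

```python
import math


def importance_sampling_loss(log_probs, old_log_probs, advantages):
    """Plain-Python version of -(exp(log_probs - old_log_probs) * advantages).sum()."""
    total = 0.0
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        prob_ratio = math.exp(lp - old_lp)  # pi_new(token) / pi_old(token)
        total += prob_ratio * adv
    return -total


# On-policy case: identical logprobs give a ratio of 1 for every token,
# so the loss reduces to the negated sum of advantages.
on_policy = importance_sampling_loss([-1.0, -2.0], [-1.0, -2.0], [0.5, -0.25])

# Off-policy case: the new policy assigns twice the old probability (0.5 vs 0.25),
# so the single advantage is weighted by a ratio of about 2.
off_policy = importance_sampling_loss([math.log(0.5)], [math.log(0.25)], [1.0])

print(on_policy, off_policy)
```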

---

## Known Issues / TODOs

### High Priority
- [ ] **Disk space management**: Checkpoints fill /tmp quickly on multi-batch runs
- [ ] Clean up tinker.db between test runs: `rm skyrl-tx/tx/tinker/tinker.db`
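
One possible way to approach the disk space item is to check free space before a run and prune older checkpoint entries. The sketch below uses only the standard library; `free_gib` and `prune_oldest` are hypothetical helpers, and the directory passed to `prune_oldest` is an assumption, since the notes above do not say where checkpoint files are written on disk.

```python
import shutil
import tempfile
from pathlib import Path


def free_gib(path: str) -> float:
    """Free space, in GiB, on the filesystem that holds `path`."""
    return shutil.disk_usage(path).free / 2**30


def prune_oldest(directory: str, keep: int) -> list[Path]:
    """Delete all but the `keep` most recently modified entries in `directory`."""
    if keep < 0:
        raise ValueError("keep must be non-negative")
    entries = sorted(
        Path(directory).iterdir(),
        key=lambda p: p.stat().st_mtime,
        reverse=True,  # newest first, so the tail of the list is the oldest
    )
    removed = []
    for entry in entries[keep:]:
        if entry.is_dir():
            shutil.rmtree(entry)
        else:
            entry.unlink()
        removed.append(entry)
    return removed


print(f"{free_gib(tempfile.gettempdir()):.1f} GiB free in the temp directory")
```

`prune_oldest` deletes files, so it should only ever be pointed at a directory that holds nothing but disposable checkpoints.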

### Medium Priority
- [ ] **Config management**: `backend_config` params don't fully propagate to SkyRL config
- [ ] **Prompt logprobs**: Not yet implemented (warning logged, not blocking)
- [ ] Review pcmoritz feedback on hardcoded model workaround

### Low Priority
- [ ] Add explicit tests for importance_sampling loss in test suite
- [ ] Document Tinker + SkyRL setup in quickstart guide
- [ ] Consider adding PPO loss to PolicyLossRegistry (currently only in JAX backend)

---

## How to Test

### Start Server
```bash
cd ~/tgriggs/SkyRL/skyrl-tx
rm -f tx/tinker/tinker.db # Clean database

uv run --extra skyrl_train --extra tinker -m tx.tinker.api \
    --base-model "Qwen/Qwen3-0.6B" \
    --backend skyrl_train
```

### Run RL Loop Test
```bash
cd ~/tinker-cookbook
TINKER_API_KEY=tml-test uv run --with tinker --with datasets --with torch \
    python -m tinker_cookbook.recipes.rl_loop \
    base_url=http://localhost:8000 \
    model_name="Qwen/Qwen3-0.6B" \
    batch_size=8 \
    group_size=4 \
    lora_rank=32 \
    max_tokens=128 \
    save_every=5 \
    log_path="/tmp/tinker-rl-test"
```

---

## References

- PR #999: https://github.com/NovaSky-AI/SkyRL/pull/999
- Tinker Loss Docs: https://tinker-docs.thinkingmachines.ai/losses
- RL Loop Recipe: ~/tinker-cookbook/tinker_cookbook/recipes/rl_loop.py
- Detailed Plan: ~/tgriggs/SkyRL/claude/rl-loop-verify.md
201 changes: 201 additions & 0 deletions claude/rl-loop-verify.md
@@ -0,0 +1,201 @@
# RL Loop Verification Plan - Running tinker-cookbook/rl_loop.py on SkyRL

**Date:** 2026-02-01
**Goal:** Run `~/tinker-cookbook/tinker_cookbook/recipes/rl_loop.py` with zero code changes on SkyRL backend
**Status:** ✅ Server Running - Ready for Component Tests

---

## ✅ PROGRESS UPDATE (Hack Approach - Completed!)

**What we did:** Copied entire Tinker codebase from skyrl-tx to skyrl-train (~45 min)

### Completed Steps:
1. ✅ Copied `tx/tinker/` → `skyrl_train/tinker/` + `tx/utils/` → `skyrl_train/tx_utils/`
2. ✅ Updated all imports: `tx.tinker` → `skyrl_train.tinker`
3. ✅ Deleted JAX code: `jax.py`, `loss_fns.py`, removed JAX references
4. ✅ Fixed engine subprocess: `--extra vllm -m skyrl_train.tinker.engine`
5. ✅ Added dependencies: fastapi, sqlmodel, sqlalchemy, aiosqlite, cloudpathlib, httpx
6. ✅ Server running on http://0.0.0.0:8000 with Qwen3-0.6B

**Result:** Zero tx dependencies, no JAX conflicts, API server fully operational!

---

## Executive Summary

All required functionality for rl_loop.py is already implemented on branch `tyler/tinker-sampling-main`:
- ✅ Sampling with response logprobs
- ✅ save_weights_for_sampler() for weight sync
- ✅ forward_backward(loss_fn="importance_sampling")
- ✅ Checkpoint save/load/resume

**Key Finding:** rl_loop.py does NOT require prompt logprobs (it only asserts response logprobs, at line 188), which removes one major TODO from the critical path.

**Current Branch:** `tyler/tinker-sampling-main` ✅

---

## Phase 1: Verification (Est: 2-4 hours)

### Commands to Run

#### 1. Start SkyRL Tinker API Server ✅ DONE!

**Command used (from skyrl-train):**
```bash
cd ~/SkyRL/skyrl-train

uv run --extra vllm python -m skyrl_train.tinker.api \
    --base-model "Qwen/Qwen3-0.6B" \
    --backend skyrl_train
```

**What to look for in logs:**
- ✅ "Created 1 inference engines for sampling" (NOT "SFT-only mode")
- ✅ "Application startup complete"
- ✅ "Uvicorn running on http://0.0.0.0:8000"

**Health check (run in separate terminal):**
```bash
curl http://localhost:8000/health
```
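
When the check needs to run from a script rather than by hand, the same `/health` request can be polled until the server answers. A minimal sketch using only the standard library; `http_ok` and `wait_until_healthy` are hypothetical helper names, and treating an HTTP 200 response as "healthy" is an assumption about the endpoint.

```python
import time
import urllib.error
import urllib.request
from typing import Callable


def http_ok(url: str, timeout_s: float = 2.0) -> bool:
    """Return True if a GET on `url` answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout_s) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        # Connection refused, timeout, or a non-2xx status raised as HTTPError.
        return False


def wait_until_healthy(
    url: str,
    probe: Callable[[str], bool] = http_ok,
    attempts: int = 30,
    delay_s: float = 2.0,
) -> bool:
    """Call `probe(url)` up to `attempts` times, sleeping `delay_s` between tries."""
    for attempt in range(attempts):
        if probe(url):
            return True
        if attempt < attempts - 1:
            time.sleep(delay_s)
    return False
```

Calling `wait_until_healthy("http://localhost:8000/health")` before launching the RL loop avoids racing the server startup. The `probe` parameter exists so the retry logic can be exercised without a live server.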

#### 2. Component Tests (Claude will create and run these)

Once the server is running, Claude will create and run:
- Test 1: Sampling with response logprobs (num_samples=2)
- Test 2: Importance sampling loss function
- Test 3: Checkpoint save/load

#### 3. Run rl_loop.py End-to-End

**Quick 2-batch smoke test:**
```bash
cd ~/tinker-cookbook

uv run --with tinker --with datasets --with torch python -m tinker_cookbook.recipes.rl_loop \
    base_url=http://localhost:8000 \
    model_name="Qwen/Qwen2.5-0.5B-Instruct" \
    batch_size=8 \
    group_size=4 \
    lora_rank=32 \
    max_tokens=128 \
    log_path="/tmp/tinker-rl-test"
```

**Success criteria:**
- Script completes 2+ batches without errors
- Checkpoints saved to /tmp/tinker-rl-test/
- Metrics logged with reward values
- No Python exceptions
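
The batch-count and reward criteria above could be checked mechanically from the metrics file. The sketch below assumes the file holds one JSON object per line, one line per batch, and that reward metrics live under keys containing the substring `reward`; none of this is confirmed by the notes here, so the key matching will likely need adjusting. `summarize_metrics` is a hypothetical helper.

```python
import json
from pathlib import Path


def summarize_metrics(path: str) -> dict:
    """Count logged batches and average every numeric field whose key mentions 'reward'."""
    batches = 0
    rewards = []
    for line in Path(path).read_text().splitlines():
        if not line.strip():
            continue  # tolerate blank lines between records
        record = json.loads(line)
        batches += 1
        for key, value in record.items():
            # Assumption: reward metrics are plain numbers under keys containing "reward".
            if "reward" in key and isinstance(value, (int, float)):
                rewards.append(float(value))
    mean_reward = sum(rewards) / len(rewards) if rewards else None
    return {"batches": batches, "reward_values": len(rewards), "mean_reward": mean_reward}
```

For the smoke test above, `summarize_metrics("/tmp/tinker-rl-test/metrics.jsonl")` should report at least 2 batches and a non-empty set of reward values if the assumptions about the file layout hold.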

---

## Common Issues & Solutions

| Issue | Cause | Solution |
|-------|-------|----------|
| `NotImplementedError: Sampling not supported` | Wrong branch or server config | Verify on tyler/tinker-sampling-main |
| `KeyError: 'logprobs'` at rl_loop.py:188 | Sampling not returning logprobs | Check skyrl_train.py:240-280 |
| `Unknown loss function: importance_sampling` | Loss not registered | Check tx/tinker/loss_fns.py:42 |
| OOM during sampling | Too many samples | Reduce to batch_size=4, group_size=2 |
| Server shows "SFT-only mode" | num_inference_engines=0 | Check backend-config has num_inference_engines=1 |

---

## Critical Files

1. **~/SkyRL/skyrl-tx/tx/tinker/backends/skyrl_train.py** (509 lines)
   - Lines 207-300: sample() with logprobs
   - Lines 301-350: save_weights_for_sampler()
   - Lines 400-509: checkpoint methods

2. **~/SkyRL/skyrl-train/skyrl_train/workers/worker_dispatch.py**
   - Lines 157-202: forward_backward(loss_fn=...)
   - Lines 318-338: save_weights_for_sampler()

3. **~/SkyRL/skyrl-tx/tx/tinker/loss_fns.py**
   - Line 42: "importance_sampling" in LOSS_FUNCTION_MAP

4. **~/tinker-cookbook/tinker_cookbook/recipes/rl_loop.py**
   - Lines 148-154: save_weights_for_sampler()
   - Lines 168-172: sample(num_samples=group_size)
   - Line 188: Assert response logprobs not None
   - Line 235: forward_backward(loss_fn="importance_sampling")

---

## Next Steps After Verification

### If Successful ✅
1. Add round-trip checkpoint tests
2. Implement config management (backend_config → SkyRL config)
3. Document setup in quickstart guide
4. Clean up and merge to main

### If Failed ❌
1. Check server logs in /tmp/skyrl-tinker-server.log
2. Verify branch with `wc -l skyrl_train.py` (should be ~509, not 220)
3. Test components individually
4. Add debug logging with `export SKYRL_LOG_LEVEL=DEBUG`

---

## What rl_loop.py Actually Requires

From detailed code analysis:

**Required APIs:**
1. `ServiceClient.create_lora_training_client(base_model, rank)` ✅
2. `TrainingClient.save_weights_for_sampler(name, ttl_seconds)` ✅
3. `ServiceClient.create_sampling_client(model_path)` ✅
4. `SamplingClient.sample(prompt, num_samples, sampling_params)` ✅
5. `TrainingClient.forward_backward(datums, loss_fn="importance_sampling")` ✅
6. `TrainingClient.optim_step(adam_params)` ✅
7. `ServiceClient.create_training_client_from_state_with_optimizer(path)` ✅

**Required Data Types:**
- `Datum(model_input, loss_fn_inputs)` ✅
- `loss_fn_inputs: {target_tokens, logprobs, advantages}` ✅
- `SampleOutput.sequences[].logprobs` (response only, NOT prompt) ✅
- `TensorData.from_torch()` serialization ✅

**NOT Required:**
- ❌ Prompt logprobs (rl_loop.py line 188 only asserts response logprobs)
- ❌ Multi-checkpoint sampling (always uses latest from line 148)

---

## Fallback Options

If verification reveals insurmountable issues:

**Option 1: Use JAX Backend**
```bash
uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model "Qwen/Qwen2.5-0.5B-Instruct" \
    --backend jax
```

**Option 2: Debug specific components** based on error messages

**Option 3: Use external Tinker API** (e.g., Thinking Machines)

---

## Timeline

- **Phase 1 (Verification):** 2-4 hours ← WE ARE HERE
- **Phase 2 (Debug, if needed):** 4-8 hours
- **Phase 3 (Hardening):** 8-12 hours
- **Total:** 10-24 hours depending on verification outcome

---

## References

- Project Summary: ~/claude-docs/skyrl/project-summary.md
- Full Plan: ~/.claude/plans/tidy-coalescing-otter.md
- RL Loop Source: ~/tinker-cookbook/tinker_cookbook/recipes/rl_loop.py
- SkyRL Backend: ~/SkyRL/skyrl-tx/tx/tinker/backends/skyrl_train.py