diff --git a/.agents/skills/train-rl/SKILL.md b/.agents/skills/train-rl/SKILL.md new file mode 100644 index 000000000..23f2b8835 --- /dev/null +++ b/.agents/skills/train-rl/SKILL.md @@ -0,0 +1,386 @@ +--- +name: train-rl +description: RL training reference for the ART framework. Use when the user asks to create, write, or help with an RL training script, reinforcement learning, GRPO, reward functions, RULER scoring, rollout functions, or anything related to RL fine-tuning. +--- + +# RL Training Wizard + +You are guiding the user through setting up Reinforcement Learning (RL) training for a language model using the ART framework. Act as an interactive wizard: ask questions, validate inputs, and generate a complete runnable script. + +**Important**: Ask ONE question at a time. Wait for the user's response before asking the next question. Never bundle multiple questions into a single message. + +**Adaptability note**: Some steps reference tools like AskUserQuestion, Glob, or Bash. If you don't have access to these tools, simply ask the user the same questions as plain text and skip any steps that require running code (e.g., file search, dataset validation, hyperparameter computation). Do NOT fabricate results — never pretend you ran a tool or searched for files when you didn't. + +## Step 1: Single-turn or Multi-turn + +Ask the user using AskUserQuestion: + +1. **Single-turn** — The agent responds to a prompt once. A reward is assigned based on that single response (e.g., solve a math problem, classify text, answer a question). +2. **Multi-turn** — The agent interacts over multiple turns with an environment, tools, or a game. A reward is assigned at the end of all turns (e.g., play a board game, use tools to complete a task, navigate a conversation). + +## Step 2: Describe the Task + +Ask the user to describe what the agent needs to do. Tell them you will help create a draft of the rollout function and environment, but they will likely need to edit it once the script is generated. + +Gather: +- **Task description** — What does the agent need to accomplish? +- **Scenarios/inputs** — How are training inputs generated or provided? (e.g., a list of problems, a game generator, a dataset of tasks) +- **System prompt** (optional) — Any system-level instructions for the agent + +For **multi-turn** scenarios, also ask: +- **Does the agent use tool calling?** — If yes, gather tool names, descriptions, parameter schemas (OpenAI function calling format), and how tool calls are executed (local function, API, MCP server, etc.) +- **How does the environment work?** — What observations does the agent receive? What actions can it take? How does a turn work? +- **When does an episode end?** — Win/loss conditions, turn limits, max tool calls, etc. (default max turns: 10) + +Help the user flesh out incomplete descriptions. Offer to write helper functions (game logic, tool execution, scenario generators) as part of the final script. + +## Step 3: Reward Method + +Ask the user using AskUserQuestion: + +1. **Programmatic reward** — You have a ground truth or scoring function to compute the reward (e.g., check correctness against an answer, game win/loss, composite score) +2. **RULER (LLM-as-judge)** — An LLM judge scores and compares the trajectories. No manual reward function needed. Requires an OpenAI API key (`OPENAI_API_KEY` env var). + +If they choose **programmatic reward**, help them design a reward function. Common patterns: +- **Binary**: 1 for correct, 0 for incorrect +- **Accuracy**: fraction of correct sub-answers (0.0 to 1.0) +- **Game outcome**: 1 for win, 0.5 for draw, 0 for loss, -1 for invalid move +- **Scaled score**: logarithmic or normalized continuous score +- **Composite**: weighted combination of multiple signals + +The reward must be a float assigned to `trajectory.reward`. Additional signals can go in `trajectory.metrics` for W&B logging. **Important: `metrics` values must be numeric (`float`, `int`) or `bool` — strings are not allowed and will cause a Pydantic validation error.** + +If they choose **RULER**, ask for: +- **Judge model**: Recommend `openai/o4-mini` (default) or `openai/o3` for higher quality + +## Step 4: Gather Base Parameters + +Do NOT ask the user to review or confirm their answers after collecting them — just proceed to the next step. + +- **Base model**: Recommend ONLY these models: + - `OpenPipe/Qwen3-14B-Instruct` + - `Qwen/Qwen3-30B-A3B-Instruct-2507` + - `meta-llama/Llama-3.1-8B-Instruct` +- **Project name**: A name for this training project (default: `rl-project`) +- **Run name**: A static, descriptive name (e.g., `math-solver-001`, `game-agent-001`). Ask the user for a meaningful name. Do NOT generate random names. + +## Step 5: Gather Hyperparameters + +Present these defaults to the user, then ask using AskUserQuestion: +- **Use defaults (Recommended)** — show all values in the description +- **Customize** — adjust individual hyperparameters + +Default values: +- **Learning rate**: `1e-5` +- **Number of training steps**: `50` +- **Rollouts per group**: `8` (number of trajectories per scenario per step; more = better advantage estimation but slower). For RULER, default to `16`. +- **Groups per step**: `1` (number of different scenarios per training step) + +If they choose "Customize", ask which parameters to change. + +## Step 6: Generate the Training Script + +Write a complete, runnable Python script by combining the appropriate **rollout pattern** (from Step 1/2) with the appropriate **reward method** (from Step 3) and the **training loop**. + +Every script MUST: +- Call `await backend.close()` at the end so the process doesn't hang +- Print post-training info and usage examples (see shared block below) + +### Post-training block (append to ALL scripts before `backend.close()`): +```python + # --- Training complete --- + step = await model.get_step() + inference_name = model.get_inference_name() + client = model.openai_client() + + print("\n" + "=" * 60) + print("RL TRAINING COMPLETE") + print("=" * 60) + print(f" Model: {inference_name}") + print(f" Base model: ") + print(f" Training step: {step}") + print(f" Inference URL: {client.base_url}") + print("=" * 60) + + print("\n--- Python usage (openai SDK) ---\n") + print(f'''\ +from openai import OpenAI + +client = OpenAI( + base_url="{client.base_url}", + api_key="not-needed", +) + +response = client.chat.completions.create( + model="{inference_name}", + messages=[ + {{"role": "user", "content": "Your prompt here"}}, + ], +) +print(response.choices[0].message.content) +''') + + print("--- curl usage ---\n") + print(f'''\ +curl {client.base_url}chat/completions \\ + -H "Content-Type: application/json" \\ + -d '{{ + "model": "{inference_name}", + "messages": [ + {{"role": "user", "content": "Your prompt here"}} + ] + }}' +''') + + await backend.close() +``` + +### Rollout pattern: Single-turn +```python +async def rollout(model: art.Model, scenario: dict) -> art.Trajectory: + client = AsyncOpenAI( + base_url=model.inference_base_url, + api_key=model.inference_api_key, + ) + messages = [ + # {"role": "system", "content": ""}, + {"role": "user", "content": scenario["prompt"]}, + ] + response = await client.chat.completions.create( + model=model.get_inference_name(), + messages=messages, + temperature=0.7, + ) + choice = response.choices[0] + + # --- Compute reward (if programmatic) --- + reward = # e.g., 1.0 if correct else 0.0 + + return art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward, + metrics={"acc": reward}, + ) +``` + +### Rollout pattern: Multi-turn (environment/game loop) +```python +async def rollout(model: art.Model, scenario) -> art.Trajectory: + client = AsyncOpenAI( + base_url=model.inference_base_url, + api_key=model.inference_api_key, + ) + game = create_game() + traj = art.Trajectory( + messages_and_choices=[ + {"role": "system", "content": ""}, + ], + reward=0.0, + ) + + while not is_finished(game): + traj.messages_and_choices.append( + {"role": "user", "content": render_observation(game)} + ) + response = await client.chat.completions.create( + model=model.get_inference_name(), + messages=traj.messages(), + temperature=0.7, + max_completion_tokens=256, + ) + choice = response.choices[0] + traj.messages_and_choices.append(choice) + try: + apply_action(game, choice.message.content) + except ValueError: + traj.reward = -1.0 + return traj + + traj.reward = compute_reward(game) + return traj +``` + +### Rollout pattern: Multi-turn with tool calling +```python +async def rollout(model: art.Model, scenario: dict) -> art.Trajectory: + client = AsyncOpenAI( + base_url=model.inference_base_url, + api_key=model.inference_api_key, + ) + MAX_TURNS = + traj = art.Trajectory( + messages_and_choices=[ + # {"role": "system", "content": ""}, + {"role": "user", "content": scenario["task"]}, + ], + tools=tools, + reward=0.0, + ) + + for turn in range(MAX_TURNS): + response = await client.chat.completions.create( + model=model.get_inference_name(), + messages=traj.messages(), + tools=tools, + temperature=0.7, + ) + choice = response.choices[0] + traj.messages_and_choices.append(choice) + + if choice.message.tool_calls: + for tc in choice.message.tool_calls: + args = json.loads(tc.function.arguments) + result = execute_tool(tc.function.name, args) + traj.messages_and_choices.append({ + "role": "tool", + "tool_call_id": tc.id, + "content": str(result), + }) + else: + break # Agent finished (no more tool calls) + + # --- Compute reward (if programmatic) --- + traj.reward = + return traj +``` + +### Reward method: RULER addition +When using RULER, the rollout function should set `reward=0.0` (RULER fills it in). Add this scoring block inside the training loop, after `gather_trajectory_groups` and before `model.train`: + +```python +from art.rewards import ruler_score_group + + # Score with RULER (LLM judge assigns relative rewards 0-1) + judged_groups = [] + for group in finished_groups: + judged = await ruler_score_group( + group, + judge_model=JUDGE_MODEL, + debug=True, + ) + judged_groups.append(judged) + finished_groups = judged_groups +``` + +### Training loop (shared by all patterns): +```python +"""RL training script generated by /train-rl wizard.""" +import asyncio +import json +from openai import AsyncOpenAI +import art +from art.local import LocalBackend + +# --- Scenarios --- +scenarios = [ + # Define or load your training scenarios here. +] + +# --- Rollout function --- +# (insert the appropriate rollout pattern here) + +# --- Training loop --- +async def main(): + backend = LocalBackend() + model = art.TrainableModel( + name="", + project="", + base_model="", + _internal_config=art.dev.InternalModelConfig( + engine_args={"gpu_memory_utilization": 0.7}, + ), + ) + await model.register(backend) + + NUM_STEPS = + ROLLOUTS_PER_GROUP = + GROUPS_PER_STEP = + + for step in range(await model.get_step(), NUM_STEPS): + groups = [ + art.TrajectoryGroup( + rollout(model, scenarios[ + (step * GROUPS_PER_STEP + i) % len(scenarios) + ]) + for _ in range(ROLLOUTS_PER_GROUP) + ) + for i in range(GROUPS_PER_STEP) + ] + finished_groups = await art.gather_trajectory_groups( + groups, pbar_desc=f"step {step}" + ) + + # (insert RULER scoring block here if using LLM-as-judge) + + avg_reward = sum( + t.reward for g in finished_groups for t in g.trajectories + ) / max(1, sum(len(g.trajectories) for g in finished_groups)) + print(f"Step {step}: avg_reward={avg_reward:.3f}") + + await model.delete_checkpoints() + await model.train( + finished_groups, + config=art.TrainConfig(learning_rate=), + ) + + # ... post-training block + backend.close() ... + +if __name__ == "__main__": + asyncio.run(main()) +``` + +### Alternative loop: Dataset-driven with iterate_dataset +When the user has a fixed list of training scenarios and wants epoch-based iteration, use `iterate_dataset` instead of the manual step loop. This can be combined with any rollout pattern and reward method. + +```python +from art.utils import iterate_dataset + + # Replace the manual for-loop with: + training_iterator = iterate_dataset( + scenarios, + groups_per_step=, + num_epochs=, + initial_step=await model.get_step(), + ) + + for batch in training_iterator: + groups = [ + art.TrajectoryGroup( + rollout(model, item) for _ in range(ROLLOUTS_PER_GROUP) + ) + for item in batch.items + ] + finished_groups = await art.gather_trajectory_groups( + groups, pbar_desc=f"epoch {batch.epoch} step {batch.step}" + ) + + # (insert RULER scoring block here if using LLM-as-judge) + + avg_reward = sum( + t.reward for g in finished_groups for t in g.trajectories + ) / max(1, sum(len(g.trajectories) for g in finished_groups)) + print(f"Step {batch.step} (epoch {batch.epoch}): avg_reward={avg_reward:.3f}") + + await model.delete_checkpoints() + await model.train( + finished_groups, + config=art.TrainConfig(learning_rate=), + ) +``` + +## Step 7: Write and Offer to Run + +1. Write the script to a file (suggest `rl_train.py`) +2. Ask the user if they want to run it now with `uv run python ` +3. If yes, run it **directly using the Bash tool** (do NOT delegate to a Task subagent) so training logs stream live to the user. Use a **2-minute timeout**. If it times out, check progress and decide whether to continue. +4. **GPU memory errors**: If training fails with OOM, lower `gpu_memory_utilization` in the existing `_internal_config` (e.g. from `0.7` to `0.5`). +5. **Stale GPU memory**: If available GPU memory looks too small, previous training runs may still be occupying memory. Before retrying, run `nvidia-smi` to check, and if needed kill leftover processes with `kill ` to free memory. + +## Important Notes + +- LocalBackend requires a GPU. +- RL uses **GRPO** (Group Relative Policy Optimization) under the hood. It needs multiple trajectories per scenario (a `TrajectoryGroup`) to compute relative advantages. More rollouts per group = better advantage estimation. +- **RULER** eliminates the need for manual reward engineering by using an LLM judge to compare trajectories within a group. It requires an OpenAI API key (`OPENAI_API_KEY` env var). +- The `@art.retry` decorator can wrap rollout functions to handle transient errors: `@art.retry(exceptions=(openai.LengthFinishReasonError,))`. +- **Validation**: To log validation metrics without training, use `await model.log(val_groups)` or `await model.log(val_groups, split="val")`. +- **Resuming**: All patterns use `await model.get_step()` as the loop start, so training resumes from the last checkpoint automatically. diff --git a/.agents/skills/train-sft/SKILL.md b/.agents/skills/train-sft/SKILL.md new file mode 100644 index 000000000..2bdcc5739 --- /dev/null +++ b/.agents/skills/train-sft/SKILL.md @@ -0,0 +1,298 @@ +--- +name: train-sft +description: SFT training reference for the ART framework. Use when the user asks to create, write, or help with an SFT training script, fine-tune a model, train from a JSONL dataset, do distillation, or anything related to supervised fine-tuning. +--- + +# SFT Training Wizard + +You are guiding the user through setting up Supervised Fine-Tuning (SFT) for a language model using the ART framework. Act as an interactive wizard: ask questions, validate inputs, and generate a complete runnable script. + +**Important**: Ask ONE question at a time. Wait for the user's response before asking the next question. Never bundle multiple questions into a single message. + +**Adaptability note**: Some steps reference tools like AskUserQuestion, Glob, or Bash. If you don't have access to these tools, simply ask the user the same questions as plain text and skip any steps that require running code (e.g., file search, dataset validation, hyperparameter computation). Do NOT fabricate results — never pretend you ran a tool or searched for files when you didn't. + +## Step 1: Determine Training Scenario + +Ask the user ONE question at a time. Wait for their response before moving to the next question. + +**Training scenario:** +1. **Train from a JSONL file** — They have a dataset file with chat-formatted examples +2. **Distillation** — They want to train a smaller model using outputs from a larger teacher model + +## Step 2: Determine Backend + +**Backend:** +1. **ServerlessBackend (Recommended)** — Train on remote managed GPUs. No local GPU needed, production-ready inference endpoint. +2. **LocalBackend** — Train on your local GPU. Full control, fast iteration. + +## Step 3: Select and Validate Dataset (JSONL scenario) + +**IMPORTANT**: Do NOT assume a dataset. Do NOT make up or hallucinate file paths. Never pretend you searched for files if you didn't actually run a search tool. + +If you have access to file system tools (Glob) and can actually execute them, search for `.jsonl` files using Glob (`**/*.jsonl`). Present real results as options. Always include "Provide my own file path" as the last option. + +Otherwise, ask the user: "What is the path to your JSONL training file?" — nothing more. + +Once the user has provided a file path, validate it if you can run code using the script below. If you cannot run code, skip validation and move on. + +```python +import json, sys +ROLES = {"system", "user", "assistant", "developer", "tool", "function"} +errors = [] +for i, line in enumerate(open(sys.argv[1]), 1): + try: + r = json.loads(line) + msgs = r.get("input", r).get("messages", []) + assert isinstance(msgs, list) and msgs, "no messages" + for j, m in enumerate(msgs): + assert m.get("role") in ROLES, f"messages[{j}]: invalid role {m.get('role')!r}" + assert m.get("content") or m.get("function_call") or m.get("tool_calls"), f"messages[{j}]: no content" + if "input" not in r: + assert msgs[-1]["role"] == "assistant", "last message must be from assistant" + tools = r.get("tools") + if tools is not None: + assert isinstance(tools, list), "tools must be a list" + except Exception as e: + errors.append(f" Line {i}: {e}") +print(f"{len(errors)} error(s):\n" + "\n".join(errors) if errors else f"Valid! {i} rows") +sys.exit(1 if errors else 0) +``` + +The JSONL format supports these fields per row: +- **`messages`** (required): List of chat messages +- **`tools`** (optional): List of tool/function definitions for tool-call training +- **`response_format`** (optional): Structured output schema (not used during training, but useful as metadata) + +Report the row count and validation result to the user. Do NOT read the whole dataset file. Do NOT name the dataset. If the format is wrong, help them fix it or convert their data. + +## Step 4: Gather Base Parameters + +Do NOT ask the user to review or confirm their answers after collecting them — just proceed to the next step. + +- **Base model**: Recommend ONLY these models: + - `OpenPipe/Qwen3-14B-Instruct` + - `Qwen/Qwen3-30B-A3B-Instruct-2507` + - `meta-llama/Llama-3.1-8B-Instruct` +- **Project name**: A name for this training project (default: `sft-project`) +- **Run name**: A static, descriptive name (e.g., `agent-001`, `pii-redactor-001`, `math-tutor-001`). Ask the user for a meaningful name. Do NOT generate random names. + +For **distillation** also ask: +- **Teacher model**: The larger model to distill from (e.g., an OpenRouter model) +- **Teacher API base URL and key**: If using a third-party provider +- **Prompts**: What prompts to send to the teacher model + +## Step 5: Gather Hyperparameters + +This step only applies if you can run code AND know the row count from validation. If you cannot run code, skip this step entirely — do NOT make up or guess hyperparameter values. The `train_sft_from_file` function has sensible built-in defaults. + +Run this Python snippet via Bash to compute defaults (replace `NUM_ROWS` with the actual row count). Do NOT show any formulas or calculation steps to the user — only show the final values. + +```python +import math, sys +n = int(sys.argv[1]) +epochs = max(1, min(10, round(10000 / n))) +batch_size = 2 +total_steps = math.ceil(n * epochs / batch_size) +steps_per_epoch = math.ceil(n / batch_size) +warmup_steps = max(10, min(1000, round(steps_per_epoch * 0.05))) +warmup_ratio = round(warmup_steps / total_steps, 4) +print(f"epochs={epochs} batch_size={batch_size} lr=2e-4 schedule=linear warmup_ratio={warmup_ratio}") +``` + +Present the output values to the user, then ask: +- **Use defaults (Recommended)** — show all values in the description +- **Customize** — adjust individual hyperparameters + +If they choose "Customize", ask which parameters to change. + +### For distillation: +Use the same defaults computation as JSONL (replace `NUM_ROWS` with the number of trajectories). `create_sft_dataset_iterator` handles the LR schedule automatically. + +## Step 6: Generate the Training Script + +Write a complete, runnable Python script. Use the patterns below. Every script MUST: +- Call `await backend.close()` at the end so the process doesn't hang +- Print post-training info and usage examples (see shared block below) + +### Post-training block (append to ALL scripts before `backend.close()`): +```python + # --- Training complete --- + step = await model.get_step() + inference_name = model.get_inference_name() + client = model.openai_client() + + print("\n" + "=" * 60) + print("SFT TRAINING COMPLETE") + print("=" * 60) + print(f" Model: {inference_name}") + print(f" Base model: ") + print(f" Training step: {step}") + print(f" Inference URL: {client.base_url}") + print(f" W&B run: https://wandb.ai///runs/") + print("=" * 60) + + print("\n--- Python usage (openai SDK) ---\n") + print(f'''\ +from openai import OpenAI + +client = OpenAI( + base_url="{client.base_url}", + api_key="not-needed", +) + +response = client.chat.completions.create( + model="{inference_name}", + messages=[ + {{"role": "user", "content": "Your prompt here"}}, + ], +) +print(response.choices[0].message.content) +''') + + print("--- curl usage ---\n") + print(f'''\ +curl {client.base_url}chat/completions \\ + -H "Content-Type: application/json" \\ + -d '{{ + "model": "{inference_name}", + "messages": [ + {{"role": "user", "content": "Your prompt here"}} + ] + }}' +''') + + await backend.close() +``` + +### Backend setup + +Use the appropriate backend based on the user's choice: + +**LocalBackend:** +```python +from art.local import LocalBackend + +backend = LocalBackend() +model = art.TrainableModel( + name="", + project="", + base_model="", + _internal_config=art.dev.InternalModelConfig( + engine_args={"gpu_memory_utilization": 0.7}, + ), +) +await model.register(backend) +``` + +**ServerlessBackend:** +```python +from art.serverless.backend import ServerlessBackend + +backend = ServerlessBackend() # uses WANDB_API_KEY env var +model = art.TrainableModel( + name="", + project="", + base_model="", +) +await model.register(backend) +``` + +Note: `_internal_config` with `gpu_memory_utilization` is only used with LocalBackend. Do NOT include it for ServerlessBackend. + +### JSONL file training pattern: + +If hyperparameters were computed in Step 5, pass them explicitly. If Step 5 was skipped, omit them — `train_sft_from_file` has sensible defaults. + +```python +"""SFT training script generated by /train-sft wizard.""" +import asyncio +import art + +from art.utils.sft import train_sft_from_file + +async def main(): + + + await train_sft_from_file( + model=model, + file_path="", + # Only include these if hyperparameters were computed: + # epochs=, + # batch_size=, + # peak_lr=, + # schedule_type="", + # warmup_ratio=, + verbose=True, + ) + + # ... post-training block + backend.close() ... + +if __name__ == "__main__": + asyncio.run(main()) +``` + +### Distillation pattern: +```python +"""Distillation SFT script generated by /train-sft wizard.""" +import asyncio, os +from dotenv import load_dotenv +from openai import AsyncOpenAI +import art + +from art.utils.sft import create_sft_dataset_iterator + +load_dotenv() + +async def main(): + teacher_client = AsyncOpenAI( + api_key=os.environ[""], + base_url="", + ) + prompts = ["", ""] + + trajectories = [] + for prompt in prompts: + completion = await teacher_client.chat.completions.create( + model="", + messages=[{"role": "user", "content": prompt}], + ) + trajectories.append( + art.Trajectory( + messages_and_choices=[ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": completion.choices[0].message.content}, + ], + tools=, + ) + ) + + + + for chunk in create_sft_dataset_iterator( + trajectories, + epochs=, + batch_size=, + peak_lr=, + schedule_type="", + warmup_ratio=, + ): + await model.train_sft(chunk.trajectories, chunk.config, verbose=True) + + # ... post-training block + backend.close() ... + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Step 7: Write and Offer to Run + +1. Write the script to a file (suggest `sft_train.py`) +2. Ask the user if they want to run it now with `uv run python ` +3. If yes, run it **directly using the Bash tool** (do NOT delegate to a Task subagent) so training logs stream live to the user. Use a **2-minute timeout**. If it times out, check progress and decide whether to continue. +4. **LocalBackend only — GPU memory errors**: If training fails with OOM, lower `gpu_memory_utilization` in the existing `_internal_config` (e.g. from `0.7` to `0.5`). +5. **LocalBackend only — Stale GPU memory**: If available GPU memory looks too small, previous training runs may still be occupying memory. Before retrying, run `nvidia-smi` to check, and if needed kill leftover processes with `kill ` to free memory. + +## Important Notes + +- LocalBackend requires a GPU. +- ServerlessBackend requires a `WANDB_API_KEY` environment variable. diff --git a/.claude/skills b/.claude/skills new file mode 120000 index 000000000..2b7a412b8 --- /dev/null +++ b/.claude/skills @@ -0,0 +1 @@ +../.agents/skills \ No newline at end of file diff --git a/.gitignore b/.gitignore index 748d300f7..bc0764abb 100644 --- a/.gitignore +++ b/.gitignore @@ -15,7 +15,7 @@ replays/ trajectories/ .DS_Store .local/ -.claude/ +.claude/settings.local.json .vscode/ .ruff_cache/ !/src/art/wandb/ diff --git a/dev/sft/dataset.jsonl b/dev/sft/dataset.jsonl new file mode 100644 index 000000000..87ed654b2 --- /dev/null +++ b/dev/sft/dataset.jsonl @@ -0,0 +1,10 @@ +{"messages":[{"role":"system","content":"\n You are an expert model trained to redact potentially sensitive information from documents. You have been given a document to redact. The document is a medical_record. Your goal is to accurately redact the sensitive information from the document. Sensitive information can be in one of the following categories:\n \n - age: a person's age\n- credit_card_info: a credit card number, expiration date, CCV, etc.\n- nationality: a country when used to reference place of birth, residence, or citizenship\n- date: a specific calendar date\n- date_of_birth: a specific calendar date representing birth\n- domain_name: a domain on the internet\n- email_address: an email ID\n- demographic_group: Anything that identifies race or ethnicity\n- gender: a gender identifier\n- personal_id: Any ID string like a national ID, subscriber number, etc.\n- other_id: Any ID not associated with a person like an organization ID, database ID, etc.\n- banking_number: a number associated with a bank account\n- medical_condition: A diagnosis, treatment code or other information identifying a medical condition\n- organization_name: name of an organization\n- person_name: name of a person\n- phone_number: a telephone number\n- street_address: a physical address\n- password: a secure string used for authentication\n- secure_credential: any secure credential like an API key, private key, 2FA token\n- religious_affiliation: anything that identifies religious affiliation\n \n You should return the specific string that needs to be redacted, along with the category of sensitive information that it belongs to. If there is no sensitive information in the document, return no strings. \n "},{"role":"user","content":"**Confidential Medical Record**\n\n**Patient Information:**\n\n- **Name:** Victor Roberts\n- **Date of Birth:** 1982-04-07\n- **Age:** 28\n- **Gender:** Male\n- **Personal ID:** 260-56-6683\n\n---\n\n**Medical History Overview:**\n\n**Current Medical Condition:**\n- **Diagnosis:** Aneurysm \n- **Date of Diagnosis:** 2010-09-15 \n- **Presenting Symptoms:** Severe headaches, vision impairment, dizziness.\n\n**Treatment Plan:**\n- **Medications Prescribed:**\n - Metoprolol 50 mg, once daily\n - Aspirin 81 mg, once daily\n- **Scheduled Follow-ups:** Every 3 months with Neurology Department\n\n**Previous Medical History:**\n- **2001:** Tonsillectomy\n- **1996:** Admitted for severe bout of influenza\n- **Allergies:** Penicillin\n\n---\n\n**Lifestyle and Social History:**\n\n- **Occupation:** Structural Engineer\n- **Smoking Status:** Non-smoker\n- **Alcohol Consumption:** Occasional, primarily socially\n- **Exercise Habits:** Engages in light exercises twice a week\n\n---\n\n**Family Medical History:**\n\n- **Father:** Hypertension\n- **Mother:** No known medical conditions\n- **Siblings:** One brother, no known medical conditions\n\n---\n\n**Emergency Contact:**\n\n- **Name:** Linda Roberts\n- **Relationship:** Sister\n- **Phone Number:** [Redacted]\n\n---\n\n**Consultations:**\n\n**Neurologist:** Dr. Emily Chen \n**Next Appointment:** 2023-11-12 at 10:00 AM\n\n**Cardiologist:** Dr. Raj Singh \n**Next Appointment:** 2023-11-20 at 2:30 PM\n\n**Notes:**\n- Continue to monitor blood pressure closely.\n- MRI scan of the brain scheduled for 2023-12-05 to evaluate treatment effectiveness.\n\n---\n\n**Doctor's Notes:**\n\n- **Reviewed by:** Dr. Natalie Brooks, MD\n- **Date of Review:** 2023-10-19\n\n**Remarks:**\n- Patient remains attentive to prescribed treatment and lifestyle recommendations. \n- Mr. Roberts reports no new symptoms or concerns since the last visit.\n- Encourage continuation of cardiovascular adherence and stress-reducing activities.\n\n**End of Record**"},{"content":"{\"fields_to_redact\":[{\"string\":\"Victor Roberts\",\"pii_type\":\"person_name\"},{\"string\":\"1982-04-07\",\"pii_type\":\"date_of_birth\"},{\"string\":\"28\",\"pii_type\":\"age\"},{\"string\":\"Male\",\"pii_type\":\"gender\"},{\"string\":\"260-56-6683\",\"pii_type\":\"personal_id\"},{\"string\":\"Aneurysm\",\"pii_type\":\"medical_condition\"},{\"string\":\"2010-09-15\",\"pii_type\":\"date\"},{\"string\":\"Penicillin\",\"pii_type\":\"medical_condition\"},{\"string\":\"Linda Roberts\",\"pii_type\":\"person_name\"},{\"string\":\"Dr. Emily Chen\",\"pii_type\":\"person_name\"},{\"string\":\"2023-11-12\",\"pii_type\":\"date\"},{\"string\":\"Dr. Raj Singh\",\"pii_type\":\"person_name\"},{\"string\":\"2023-11-20\",\"pii_type\":\"date\"},{\"string\":\"2023-12-05\",\"pii_type\":\"date\"},{\"string\":\"Dr. Natalie Brooks, MD\",\"pii_type\":\"person_name\"},{\"string\":\"2023-10-19\",\"pii_type\":\"date\"},{\"string\":\"Hypertension\",\"pii_type\":\"medical_condition\"}]}","refusal":null,"role":"assistant"}],"tools":[],"response_format":{"type":"json_schema","json_schema":{"name":"RedactionResponse","schema":{"type":"object","$defs":{"RedactionField":{"type":"object","title":"RedactionField","required":["string","pii_type"],"properties":{"string":{"type":"string","title":"String","description":"The exact matching string to redact. Include any whitespace or punctuation. Must be an exact string match!"},"pii_type":{"enum":["age","credit_card_info","nationality","date","date_of_birth","domain_name","email_address","demographic_group","gender","personal_id","other_id","banking_number","medical_condition","organization_name","person_name","phone_number","street_address","password","secure_credential","religious_affiliation"],"type":"string","title":"Pii Type"}},"additionalProperties":false}},"title":"RedactionResponse","required":["fields_to_redact"],"properties":{"fields_to_redact":{"type":"array","items":{"$ref":"#/$defs/RedactionField"},"title":"Fields To Redact"}},"additionalProperties":false},"strict":true}},"split":"TRAIN"} +{"messages":[{"role":"system","content":"\n You are an expert model trained to redact potentially sensitive information from documents. You have been given a document to redact. The document is a company_memo. Your goal is to accurately redact the sensitive information from the document. Sensitive information can be in one of the following categories:\n \n - age: a person's age\n- credit_card_info: a credit card number, expiration date, CCV, etc.\n- nationality: a country when used to reference place of birth, residence, or citizenship\n- date: a specific calendar date\n- date_of_birth: a specific calendar date representing birth\n- domain_name: a domain on the internet\n- email_address: an email ID\n- demographic_group: Anything that identifies race or ethnicity\n- gender: a gender identifier\n- personal_id: Any ID string like a national ID, subscriber number, etc.\n- other_id: Any ID not associated with a person like an organization ID, database ID, etc.\n- banking_number: a number associated with a bank account\n- medical_condition: A diagnosis, treatment code or other information identifying a medical condition\n- organization_name: name of an organization\n- person_name: name of a person\n- phone_number: a telephone number\n- street_address: a physical address\n- password: a secure string used for authentication\n- secure_credential: any secure credential like an API key, private key, 2FA token\n- religious_affiliation: anything that identifies religious affiliation\n \n You should return the specific string that needs to be redacted, along with the category of sensitive information that it belongs to. If there is no sensitive information in the document, return no strings. \n "},{"role":"user","content":"---\n\n**Memo**\n\n**To**: All Employees \n**From**: Management \n**Date**: January 11, 2011 \n**Subject**: Update on Internal Policies and Procedures \n\nDear Team,\n\nAs we strive towards continuous excellence at Clay Group, we are implementing some updates to our internal policies and procedures, effective immediately. Please take a moment to review these changes and how they may impact your daily operations.\n\n**1. Emergency Protocol** \nThe safety and well-being of our employees are a top priority. We have updated our emergency contact procedures. In case of any urgent situations, please contact the main office line at 001-835-693-0988 with the following extension: 8706. Make sure this number is readily accessible.\n\n**2. Cybersecurity Policy** \nWith the recent rise in cybersecurity threats, it is imperative to maintain a robust security posture. Starting next month, all team members will be required to undergo additional training on the secure handling of sensitive information. Details will be communicated in due course.\n\n**3. Remote Working Policies** \nGiven the positive feedback from our previous trials, we are pleased to announce the continuation of our flexible working arrangements. Employees can choose to work remotely up to two days a week, pending manager approval and adherence to productivity benchmarks.\n\n**4. Internal Communications** \nTo enhance transparency and communication within the organization, we are introducing a new company-wide digital bulletin board. This platform will be used for announcements, team achievements, and events. More information on accessing the bulletin board will follow.\n\nYour cooperation and dedication are vital as we work together to maintain Clay Group's reputation as a leader in innovation and employee satisfaction. If you have any questions regarding these updates, do not hesitate to contact your department head.\n\nThank you for your continued hard work and commitment.\n\nBest regards,\n\n[Signature, if applicable]\n\nManagement \nClay Group\n\n---"},{"content":"{\"fields_to_redact\":[{\"string\":\"January 11, 2011\",\"pii_type\":\"date\"},{\"string\":\"Clay Group\",\"pii_type\":\"organization_name\"},{\"string\":\"001-835-693-0988\",\"pii_type\":\"phone_number\"},{\"string\":\"8706\",\"pii_type\":\"other_id\"},{\"string\":\"Clay Group\",\"pii_type\":\"organization_name\"}]}","refusal":null,"role":"assistant"}],"tools":[],"response_format":{"type":"json_schema","json_schema":{"name":"RedactionResponse","schema":{"type":"object","$defs":{"RedactionField":{"type":"object","title":"RedactionField","required":["string","pii_type"],"properties":{"string":{"type":"string","title":"String","description":"The exact matching string to redact. Include any whitespace or punctuation. Must be an exact string match!"},"pii_type":{"enum":["age","credit_card_info","nationality","date","date_of_birth","domain_name","email_address","demographic_group","gender","personal_id","other_id","banking_number","medical_condition","organization_name","person_name","phone_number","street_address","password","secure_credential","religious_affiliation"],"type":"string","title":"Pii Type"}},"additionalProperties":false}},"title":"RedactionResponse","required":["fields_to_redact"],"properties":{"fields_to_redact":{"type":"array","items":{"$ref":"#/$defs/RedactionField"},"title":"Fields To Redact"}},"additionalProperties":false},"strict":true}},"split":"TRAIN"} +{"messages":[{"role":"system","content":"\n You are an expert model trained to redact potentially sensitive information from documents. You have been given a document to redact. The document is a company_memo. Your goal is to accurately redact the sensitive information from the document. Sensitive information can be in one of the following categories:\n \n - age: a person's age\n- credit_card_info: a credit card number, expiration date, CCV, etc.\n- nationality: a country when used to reference place of birth, residence, or citizenship\n- date: a specific calendar date\n- date_of_birth: a specific calendar date representing birth\n- domain_name: a domain on the internet\n- email_address: an email ID\n- demographic_group: Anything that identifies race or ethnicity\n- gender: a gender identifier\n- personal_id: Any ID string like a national ID, subscriber number, etc.\n- other_id: Any ID not associated with a person like an organization ID, database ID, etc.\n- banking_number: a number associated with a bank account\n- medical_condition: A diagnosis, treatment code or other information identifying a medical condition\n- organization_name: name of an organization\n- person_name: name of a person\n- phone_number: a telephone number\n- street_address: a physical address\n- password: a secure string used for authentication\n- secure_credential: any secure credential like an API key, private key, 2FA token\n- religious_affiliation: anything that identifies religious affiliation\n \n You should return the specific string that needs to be redacted, along with the category of sensitive information that it belongs to. If there is no sensitive information in the document, return no strings. \n "},{"role":"user","content":"---\n\n**MEMORANDUM**\n\n**TO:** All Staff\n\n**FROM:** Amy Singleton, Director of Human Resources\n\n**DATE:** January 29, 1993\n\n**SUBJECT:** Exciting New Developments at O'Connor-Thomas!\n\nDear Team,\n\nI am thrilled to reach out to each of you on behalf of O'Connor-Thomas, an organization that prides itself on innovation and excellence. As we continue our journey of growth and success, I would like to share some exciting updates and opportunities that await us.\n\n**Innovation and Sustainability Initiatives**\n\nFirst, I am pleased to announce that we will be launching a new series of sustainability initiatives designed to reduce our environmental footprint. As a sector leader, O'Connor-Thomas is committed to responsible innovation, and we are adamant about making our operations greener. More details will be provided in the upcoming weeks.\n\n**Employee Development Workshops**\n\nWe understand the importance of continuous learning, which is why we are introducing a series of workshops focused on skill enhancement across various departments. These workshops aim to augment your technical and professional skills, ensuring our team remains at the forefront of industry trends. Participation details will soon follow.\n\n**Building a Stronger Community**\n\nOur company believes in the power of community. This year, we will increase our involvement with local charities and community events. Volunteering tells a great story of who we are as a company and individuals. Let's make a positive impact together!\n\nPlease feel free to reach out to me if you have any questions or suggestions regarding our upcoming projects. Feedback is always appreciated as it helps us build a better and more inclusive workplace. Together, we can achieve more!\n\nLooking forward to all the great things 1993 will bring for O'Connor-Thomas and each of us individually.\n\nWarm regards,\n\n**Amy Singleton** \nDirector of Human Resources \nO’Connor-Thomas \n\n---\n\nRemember, keep pushing boundaries and stay inspired!\n\n---"},{"content":"{\"fields_to_redact\":[{\"string\":\"January 29, 1993\",\"pii_type\":\"date\"},{\"string\":\"O'Connor-Thomas\",\"pii_type\":\"organization_name\"},{\"string\":\"Amy Singleton\",\"pii_type\":\"person_name\"},{\"string\":\"Amy Singleton\",\"pii_type\":\"person_name\"},{\"string\":\"O’Connor-Thomas\",\"pii_type\":\"organization_name\"}]}","refusal":null,"role":"assistant"}],"tools":[],"response_format":{"type":"json_schema","json_schema":{"name":"RedactionResponse","schema":{"type":"object","$defs":{"RedactionField":{"type":"object","title":"RedactionField","required":["string","pii_type"],"properties":{"string":{"type":"string","title":"String","description":"The exact matching string to redact. Include any whitespace or punctuation. Must be an exact string match!"},"pii_type":{"enum":["age","credit_card_info","nationality","date","date_of_birth","domain_name","email_address","demographic_group","gender","personal_id","other_id","banking_number","medical_condition","organization_name","person_name","phone_number","street_address","password","secure_credential","religious_affiliation"],"type":"string","title":"Pii Type"}},"additionalProperties":false}},"title":"RedactionResponse","required":["fields_to_redact"],"properties":{"fields_to_redact":{"type":"array","items":{"$ref":"#/$defs/RedactionField"},"title":"Fields To Redact"}},"additionalProperties":false},"strict":true}},"split":"TRAIN"} +{"messages":[{"role":"system","content":"\n You are an expert model trained to redact potentially sensitive information from documents. You have been given a document to redact. The document is a medical_record. Your goal is to accurately redact the sensitive information from the document. Sensitive information can be in one of the following categories:\n \n - age: a person's age\n- credit_card_info: a credit card number, expiration date, CCV, etc.\n- nationality: a country when used to reference place of birth, residence, or citizenship\n- date: a specific calendar date\n- date_of_birth: a specific calendar date representing birth\n- domain_name: a domain on the internet\n- email_address: an email ID\n- demographic_group: Anything that identifies race or ethnicity\n- gender: a gender identifier\n- personal_id: Any ID string like a national ID, subscriber number, etc.\n- other_id: Any ID not associated with a person like an organization ID, database ID, etc.\n- banking_number: a number associated with a bank account\n- medical_condition: A diagnosis, treatment code or other information identifying a medical condition\n- organization_name: name of an organization\n- person_name: name of a person\n- phone_number: a telephone number\n- street_address: a physical address\n- password: a secure string used for authentication\n- secure_credential: any secure credential like an API key, private key, 2FA token\n- religious_affiliation: anything that identifies religious affiliation\n \n You should return the specific string that needs to be redacted, along with the category of sensitive information that it belongs to. If there is no sensitive information in the document, return no strings. \n "},{"role":"user","content":"```plaintext\n CONFIDENTIAL MEDICAL RECORD\n\nPatient Name: Brenda Davila\nAge: 58\nGender: Male\nDate of Birth: 25th December 2022 Identifier Number: ZZ025922T\n\n--------------------------------------------------------------------------\nADDRESS: \n529 Tina Villages Suite 522\nLake Robert, CO 31592\n\nPHONE: +1-555-834-9771\n\n--------------------------------------------------------------------------\nMEDICAL DIAGNOSIS:\n\nCondition: Lead Poisoning\n\nBackground: \nLead poisoning is a type of metal poisoning caused by increased levels of the heavy metal lead in the body. Exposure to lead over the long term can cause harmful effects on the brain, liver, kidney, and reproductive systems, among others.\n\nPresenting Symptoms: \n- Abdominal Pain\n- Fatigue\n- Irritability\n- Memory Loss\n\nCurrent Treatment Plan:\n- Chelation Therapy: Succimer 10 mg/kg orally every 8 hours for 5 days\n- Regular Monitoring of Blood Lead Levels (BLL)\n- Dietary Adjustments: Ensure high intake of calcium, iron, and vitamin C\n- Environmental modifications to reduce lead exposure\n\nNext Review Appointment: 15th November 2023\n\n--------------------------------------------------------------------------\nNotes:\n- Due to age discrepancy, additional verification documents for identity and medical history may be required.\n- Careful consideration of residential environment as potential source of lead.\n- Emergency Contact details were not provided.\n\n---------------------------------------------------------------------------\n\nDoctor: Dr. Anthony Williamson \nLicense No: MD335912\nContact: a.williamson@lakerobertclinic.com \n\nThis document is intended for the sole use of the individual to whom it is addressed and may contain confidential and privileged information. Any unauthorized review, use, disclosure or distribution is prohibited. If you are not the intended recipient, please contact the sender by reply email and destroy all copies of the original message.\n```"},{"content":"{\"fields_to_redact\":[{\"string\":\"Brenda Davila\",\"pii_type\":\"person_name\"},{\"string\":\"58\",\"pii_type\":\"age\"},{\"string\":\"Male\",\"pii_type\":\"gender\"},{\"string\":\"25th December 2022\",\"pii_type\":\"date_of_birth\"},{\"string\":\"ZZ025922T\",\"pii_type\":\"personal_id\"},{\"string\":\"529 Tina Villages Suite 522\\nLake Robert, CO 31592\",\"pii_type\":\"street_address\"},{\"string\":\"+1-555-834-9771\",\"pii_type\":\"phone_number\"},{\"string\":\"Lead Poisoning\",\"pii_type\":\"medical_condition\"},{\"string\":\"15th November 2023\",\"pii_type\":\"date\"},{\"string\":\"Dr. Anthony Williamson\",\"pii_type\":\"person_name\"},{\"string\":\"a.williamson@lakerobertclinic.com\",\"pii_type\":\"email_address\"}]}","refusal":null,"role":"assistant"}],"tools":[],"response_format":{"type":"json_schema","json_schema":{"name":"RedactionResponse","schema":{"type":"object","$defs":{"RedactionField":{"type":"object","title":"RedactionField","required":["string","pii_type"],"properties":{"string":{"type":"string","title":"String","description":"The exact matching string to redact. Include any whitespace or punctuation. Must be an exact string match!"},"pii_type":{"enum":["age","credit_card_info","nationality","date","date_of_birth","domain_name","email_address","demographic_group","gender","personal_id","other_id","banking_number","medical_condition","organization_name","person_name","phone_number","street_address","password","secure_credential","religious_affiliation"],"type":"string","title":"Pii Type"}},"additionalProperties":false}},"title":"RedactionResponse","required":["fields_to_redact"],"properties":{"fields_to_redact":{"type":"array","items":{"$ref":"#/$defs/RedactionField"},"title":"Fields To Redact"}},"additionalProperties":false},"strict":true}},"split":"TRAIN"} +{"messages":[{"role":"system","content":"\n You are an expert model trained to redact potentially sensitive information from documents. You have been given a document to redact. The document is a personal_email. Your goal is to accurately redact the sensitive information from the document. Sensitive information can be in one of the following categories:\n \n - age: a person's age\n- credit_card_info: a credit card number, expiration date, CCV, etc.\n- nationality: a country when used to reference place of birth, residence, or citizenship\n- date: a specific calendar date\n- date_of_birth: a specific calendar date representing birth\n- domain_name: a domain on the internet\n- email_address: an email ID\n- demographic_group: Anything that identifies race or ethnicity\n- gender: a gender identifier\n- personal_id: Any ID string like a national ID, subscriber number, etc.\n- other_id: Any ID not associated with a person like an organization ID, database ID, etc.\n- banking_number: a number associated with a bank account\n- medical_condition: A diagnosis, treatment code or other information identifying a medical condition\n- organization_name: name of an organization\n- person_name: name of a person\n- phone_number: a telephone number\n- street_address: a physical address\n- password: a secure string used for authentication\n- secure_credential: any secure credential like an API key, private key, 2FA token\n- religious_affiliation: anything that identifies religious affiliation\n \n You should return the specific string that needs to be redacted, along with the category of sensitive information that it belongs to. If there is no sensitive information in the document, return no strings. \n "},{"role":"user","content":"Subject: Long Time No See!\n\nHi Margaud-Anaïs,\n\nI hope this email finds you well. It’s been ages since we last caught up, and I've missed our conversations. I remember the last time we chatted, you mentioned that you were working on an exciting project. How's that going? \n\nBy the way, I was reminiscing about that hilarious incident during our holiday trip back in 2003. Unbelievably, January 2nd, 2003, will always be etched in my memory as the day we nearly missed that flight because we were too engrossed in spotting cute street cats. Time really flies!\n\nAnyway, I'm reaching out because I’m planning a small get-together with some old friends next month and I would love for you to join us. Let me know if you're available – it would be so nice to catch up in person and relive some of those fun memories.\n\nPlease feel free to either shoot me an email at josephestrada@example.org or give me a call at 03 45 02 61 71 when you get a chance.\n\nLooking forward to hearing from you soon!\n\nWarm regards,\n\nJoseph"},{"content":"{\"fields_to_redact\":[{\"string\":\"Margaud-Anaïs\",\"pii_type\":\"person_name\"},{\"string\":\"January 2nd, 2003\",\"pii_type\":\"date\"},{\"string\":\"josephestrada@example.org\",\"pii_type\":\"email_address\"},{\"string\":\"03 45 02 61 71\",\"pii_type\":\"phone_number\"},{\"string\":\"Joseph\",\"pii_type\":\"person_name\"}]}","refusal":null,"role":"assistant"}],"tools":[],"response_format":{"type":"json_schema","json_schema":{"name":"RedactionResponse","schema":{"type":"object","$defs":{"RedactionField":{"type":"object","title":"RedactionField","required":["string","pii_type"],"properties":{"string":{"type":"string","title":"String","description":"The exact matching string to redact. Include any whitespace or punctuation. Must be an exact string match!"},"pii_type":{"enum":["age","credit_card_info","nationality","date","date_of_birth","domain_name","email_address","demographic_group","gender","personal_id","other_id","banking_number","medical_condition","organization_name","person_name","phone_number","street_address","password","secure_credential","religious_affiliation"],"type":"string","title":"Pii Type"}},"additionalProperties":false}},"title":"RedactionResponse","required":["fields_to_redact"],"properties":{"fields_to_redact":{"type":"array","items":{"$ref":"#/$defs/RedactionField"},"title":"Fields To Redact"}},"additionalProperties":false},"strict":true}},"split":"TRAIN"} +{"messages":[{"role":"system","content":"\n You are an expert model trained to redact potentially sensitive information from documents. You have been given a document to redact. The document is a personal_email. Your goal is to accurately redact the sensitive information from the document. Sensitive information can be in one of the following categories:\n \n - age: a person's age\n- credit_card_info: a credit card number, expiration date, CCV, etc.\n- nationality: a country when used to reference place of birth, residence, or citizenship\n- date: a specific calendar date\n- date_of_birth: a specific calendar date representing birth\n- domain_name: a domain on the internet\n- email_address: an email ID\n- demographic_group: Anything that identifies race or ethnicity\n- gender: a gender identifier\n- personal_id: Any ID string like a national ID, subscriber number, etc.\n- other_id: Any ID not associated with a person like an organization ID, database ID, etc.\n- banking_number: a number associated with a bank account\n- medical_condition: A diagnosis, treatment code or other information identifying a medical condition\n- organization_name: name of an organization\n- person_name: name of a person\n- phone_number: a telephone number\n- street_address: a physical address\n- password: a secure string used for authentication\n- secure_credential: any secure credential like an API key, private key, 2FA token\n- religious_affiliation: anything that identifies religious affiliation\n \n You should return the specific string that needs to be redacted, along with the category of sensitive information that it belongs to. If there is no sensitive information in the document, return no strings. \n "},{"role":"user","content":"Subject: Catching Up and Exciting News!\n\nHi James,\n\nI hope this email finds you well! It’s been too long since we last caught up. I wanted to take a moment to reach out and share some exciting updates with you.\n\nFirst off, can you believe it’s already been a year since we had that unforgettable trip to the Grand Canyon? Time really flies! It was truly an experience to remember.\n\nOn a personal note, I’ve embarked on a new project that I’m really passionate about. I’m working with a local non-profit organization focused on environmental conservation here in Seville. It’s been incredibly rewarding to see the impact we’re making and I’d love to tell you more about it.\n\nAlso, since you are the ultimate movie buff, I wanted to get your recommendations on the latest sci-fi flicks. I’m planning a weekend movie marathon, and nobody curates a list quite like you do!\n\nLooking forward to hearing from you soon. Let's try to set up a call or something sometime next month?\n\nBest wishes,\n\nAlejo Villaverde Morcillo\n\nP.S.: If you’re up for a little nostalgia, I stumbled upon some old photos from our hiking days. Happy to share them if you’re interested!\n\nSent on: Monday, 28th February 2005"},{"content":"{\"fields_to_redact\":[{\"string\":\"James\",\"pii_type\":\"person_name\"},{\"string\":\"Seville\",\"pii_type\":\"street_address\"},{\"string\":\"Alejo Villaverde Morcillo\",\"pii_type\":\"person_name\"},{\"string\":\"Monday, 28th February 2005\",\"pii_type\":\"date\"}]}","refusal":null,"role":"assistant"}],"tools":[],"response_format":{"type":"json_schema","json_schema":{"name":"RedactionResponse","schema":{"type":"object","$defs":{"RedactionField":{"type":"object","title":"RedactionField","required":["string","pii_type"],"properties":{"string":{"type":"string","title":"String","description":"The exact matching string to redact. Include any whitespace or punctuation. Must be an exact string match!"},"pii_type":{"enum":["age","credit_card_info","nationality","date","date_of_birth","domain_name","email_address","demographic_group","gender","personal_id","other_id","banking_number","medical_condition","organization_name","person_name","phone_number","street_address","password","secure_credential","religious_affiliation"],"type":"string","title":"Pii Type"}},"additionalProperties":false}},"title":"RedactionResponse","required":["fields_to_redact"],"properties":{"fields_to_redact":{"type":"array","items":{"$ref":"#/$defs/RedactionField"},"title":"Fields To Redact"}},"additionalProperties":false},"strict":true}},"split":"TRAIN"} +{"messages":[{"role":"system","content":"\n You are an expert model trained to redact potentially sensitive information from documents. You have been given a document to redact. The document is a support_email. Your goal is to accurately redact the sensitive information from the document. Sensitive information can be in one of the following categories:\n \n - age: a person's age\n- credit_card_info: a credit card number, expiration date, CCV, etc.\n- nationality: a country when used to reference place of birth, residence, or citizenship\n- date: a specific calendar date\n- date_of_birth: a specific calendar date representing birth\n- domain_name: a domain on the internet\n- email_address: an email ID\n- demographic_group: Anything that identifies race or ethnicity\n- gender: a gender identifier\n- personal_id: Any ID string like a national ID, subscriber number, etc.\n- other_id: Any ID not associated with a person like an organization ID, database ID, etc.\n- banking_number: a number associated with a bank account\n- medical_condition: A diagnosis, treatment code or other information identifying a medical condition\n- organization_name: name of an organization\n- person_name: name of a person\n- phone_number: a telephone number\n- street_address: a physical address\n- password: a secure string used for authentication\n- secure_credential: any secure credential like an API key, private key, 2FA token\n- religious_affiliation: anything that identifies religious affiliation\n \n You should return the specific string that needs to be redacted, along with the category of sensitive information that it belongs to. If there is no sensitive information in the document, return no strings. \n "},{"role":"user","content":"Subject: Urgent Assistance Required for Jessica Anderson\n\nFrom: janice34@example.net \nDate: 2024-07-30 \nTo: support@pmp-services.com \n\nHi Support Team,\n\nI hope this email finds you well. My name is Jessica Anderson, and I am having some trouble accessing my account with your organization, Pichardo, Montez y Padilla. I attempted several times today and was unable to log in successfully.\n\nDetails for verification:\n- Date of Birth: 1996-10-27\n- Personal ID: 564 588 200\n- Banking Number: EALY56899882811142\n- Registered Phone Number: +44(0)1174960628\n\nThe error message said something about account authorization, which I haven't encountered before. I suspect it may have something to do with recent changes in my profile information that I updated last week. Could you please verify my account status and assist me in restoring access as soon as possible?\n\nThank you for your prompt attention to this matter. I'm looking forward to your urgent response.\n\nBest regards, \nJessica Anderson"},{"content":"{\"fields_to_redact\":[{\"string\":\"Jessica Anderson\",\"pii_type\":\"person_name\"},{\"string\":\"janice34@example.net\",\"pii_type\":\"email_address\"},{\"string\":\"Pichardo, Montez y Padilla\",\"pii_type\":\"organization_name\"},{\"string\":\"1996-10-27\",\"pii_type\":\"date_of_birth\"},{\"string\":\"564 588 200\",\"pii_type\":\"personal_id\"},{\"string\":\"EALY56899882811142\",\"pii_type\":\"banking_number\"},{\"string\":\"+44(0)1174960628\",\"pii_type\":\"phone_number\"}]}","refusal":null,"role":"assistant"}],"tools":[],"response_format":{"type":"json_schema","json_schema":{"name":"RedactionResponse","schema":{"type":"object","$defs":{"RedactionField":{"type":"object","title":"RedactionField","required":["string","pii_type"],"properties":{"string":{"type":"string","title":"String","description":"The exact matching string to redact. Include any whitespace or punctuation. Must be an exact string match!"},"pii_type":{"enum":["age","credit_card_info","nationality","date","date_of_birth","domain_name","email_address","demographic_group","gender","personal_id","other_id","banking_number","medical_condition","organization_name","person_name","phone_number","street_address","password","secure_credential","religious_affiliation"],"type":"string","title":"Pii Type"}},"additionalProperties":false}},"title":"RedactionResponse","required":["fields_to_redact"],"properties":{"fields_to_redact":{"type":"array","items":{"$ref":"#/$defs/RedactionField"},"title":"Fields To Redact"}},"additionalProperties":false},"strict":true}},"split":"TRAIN"} +{"messages":[{"role":"system","content":"\n You are an expert model trained to redact potentially sensitive information from documents. You have been given a document to redact. The document is a utility_bill. Your goal is to accurately redact the sensitive information from the document. Sensitive information can be in one of the following categories:\n \n - age: a person's age\n- credit_card_info: a credit card number, expiration date, CCV, etc.\n- nationality: a country when used to reference place of birth, residence, or citizenship\n- date: a specific calendar date\n- date_of_birth: a specific calendar date representing birth\n- domain_name: a domain on the internet\n- email_address: an email ID\n- demographic_group: Anything that identifies race or ethnicity\n- gender: a gender identifier\n- personal_id: Any ID string like a national ID, subscriber number, etc.\n- other_id: Any ID not associated with a person like an organization ID, database ID, etc.\n- banking_number: a number associated with a bank account\n- medical_condition: A diagnosis, treatment code or other information identifying a medical condition\n- organization_name: name of an organization\n- person_name: name of a person\n- phone_number: a telephone number\n- street_address: a physical address\n- password: a secure string used for authentication\n- secure_credential: any secure credential like an API key, private key, 2FA token\n- religious_affiliation: anything that identifies religious affiliation\n \n You should return the specific string that needs to be redacted, along with the category of sensitive information that it belongs to. If there is no sensitive information in the document, return no strings. \n "},{"role":"user","content":"```\nSunshine Power & Utilities Co.\nCustomer Service Hotline: 1-800-432-0098\nEmail: support@sunpowerutilities.net\nWebsite: www.sunpowerutilities.net\n\n---------------------------------------------------------------------\nUTILITY BILL STATEMENT\n---------------------------------------------------------------------\n\nAccount Number: 2049-987-231 Date: 1995-07-11\n\nBilling Summary:\n----------------------------------------------\nName: Mtro. Catalina Olvera\nService Address: 749 Danielle Lakes Apt. 791\n Lake Davidton, NU E3L 3V2\n\n---------------------------------------------------------------------\n\nElectricity Consumption Details:\n----------------------------------------------\nBilling Period: Jun 01, 1995 - Jun 30, 1995\nMeter Number: EL-67234-NE\n\nPrevious Reading: 12,345 kWh\nCurrent Reading: 12,910 kWh\nTotal Usage: 565 kWh\n\nCurrent Charge:\nElectricity Charge: $0.12/kWh\nTotal Electricity Cost: $67.80\n\nOther Charges:\n- Environment Fee: $4.50\n- Service Connection Fee: $15.00\n\nTotal Amount Due: $87.30\n\n---------------------------------------------------------------------\n\nPayment Due Date: July 26, 1995\n\nPayment Methods:\n- By Phone: Call 1-800-123-5678 with your account number\n- Online: Log in to your account at www.sunpowerutilities.net\n- Mobile App: Available on iOS and Android\n- In-Person: Visit our office at 123 Solar Street, Lake Davidton\n- Mail: Use the return envelope enclosed with this bill\n\nNote: Late fees apply if payment is not received by the due date.\n\nThank you for using Sunshine Power & Utilities Co.!\n\nPlease remember to conserve energy:\n- Turn off lights when not in use.\n- Use energy-efficient appliances.\n- Set your thermostat wisely.\n\n---------------------------------------------------------------------\n\nVisit our website for tips on energy conservation and to learn more\nabout our renewable energy projects!\n---------------------------------------------------------------------\n```\n"},{"content":"{\"fields_to_redact\":[{\"string\":\"support@sunpowerutilities.net\",\"pii_type\":\"email_address\"},{\"string\":\"2049-987-231\",\"pii_type\":\"personal_id\"},{\"string\":\"1995-07-11\",\"pii_type\":\"date\"},{\"string\":\"Mtro. Catalina Olvera\",\"pii_type\":\"person_name\"},{\"string\":\"749 Danielle Lakes Apt. 791\\n Lake Davidton, NU E3L 3V2\",\"pii_type\":\"street_address\"},{\"string\":\"123 Solar Street, Lake Davidton\",\"pii_type\":\"street_address\"},{\"string\":\"1995\",\"pii_type\":\"date\"},{\"string\":\"1995\",\"pii_type\":\"date\"},{\"string\":\"July 26, 1995\",\"pii_type\":\"date\"}]}","refusal":null,"role":"assistant"}],"rejected_message":{"content":"{\"fields_to_redact\":[{\"string\":\"support@sunpowerutilities.net\",\"pii_type\":\"email_address\"},{\"string\":\"Mtro. Catalina Olvera\",\"pii_type\":\"person_name\"},{\"string\":\"749 Danielle Lakes Apt. 791\\n Lake Davidton, NU E3L 3V2\",\"pii_type\":\"street_address\"},{\"string\":\"2049-987-231\",\"pii_type\":\"personal_id\"},{\"string\":\"1995-07-11\",\"pii_type\":\"date\"},{\"string\":\"June 01, 1995 - June 30, 1995\",\"pii_type\":\"date\"},{\"string\":\"July 26, 1995\",\"pii_type\":\"date\"}]}","refusal":null,"role":"assistant","tool_calls":[]},"tools":[],"response_format":{"type":"json_schema","json_schema":{"name":"RedactionResponse","schema":{"type":"object","$defs":{"RedactionField":{"type":"object","title":"RedactionField","required":["string","pii_type"],"properties":{"string":{"type":"string","title":"String","description":"The exact matching string to redact. Include any whitespace or punctuation. Must be an exact string match!"},"pii_type":{"enum":["age","credit_card_info","nationality","date","date_of_birth","domain_name","email_address","demographic_group","gender","personal_id","other_id","banking_number","medical_condition","organization_name","person_name","phone_number","street_address","password","secure_credential","religious_affiliation"],"type":"string","title":"Pii Type"}},"additionalProperties":false}},"title":"RedactionResponse","required":["fields_to_redact"],"properties":{"fields_to_redact":{"type":"array","items":{"$ref":"#/$defs/RedactionField"},"title":"Fields To Redact"}},"additionalProperties":false},"strict":true}},"split":"TRAIN"} +{"messages":[{"role":"system","content":"\n You are an expert model trained to redact potentially sensitive information from documents. You have been given a document to redact. The document is a rental_agreement. Your goal is to accurately redact the sensitive information from the document. Sensitive information can be in one of the following categories:\n \n - age: a person's age\n- credit_card_info: a credit card number, expiration date, CCV, etc.\n- nationality: a country when used to reference place of birth, residence, or citizenship\n- date: a specific calendar date\n- date_of_birth: a specific calendar date representing birth\n- domain_name: a domain on the internet\n- email_address: an email ID\n- demographic_group: Anything that identifies race or ethnicity\n- gender: a gender identifier\n- personal_id: Any ID string like a national ID, subscriber number, etc.\n- other_id: Any ID not associated with a person like an organization ID, database ID, etc.\n- banking_number: a number associated with a bank account\n- medical_condition: A diagnosis, treatment code or other information identifying a medical condition\n- organization_name: name of an organization\n- person_name: name of a person\n- phone_number: a telephone number\n- street_address: a physical address\n- password: a secure string used for authentication\n- secure_credential: any secure credential like an API key, private key, 2FA token\n- religious_affiliation: anything that identifies religious affiliation\n \n You should return the specific string that needs to be redacted, along with the category of sensitive information that it belongs to. If there is no sensitive information in the document, return no strings. \n "},{"role":"user","content":"**RESIDENTIAL RENTAL AGREEMENT**\n\n**This Residential Lease Agreement (\"Agreement\") is entered into on the 16th day of April, 1983, by and between Patel Inc (\"Landlord\") and Brian Smith (\"Tenant\").**\n\n**Landlord:**\nPatel Inc \nRegistered Office: 47 Zenith Plaza \nCorporate ID: ALPHA-83920\n\n**Tenant:**\nBrian Smith \nContact Number: +44(0)292018253 \nCurrent Address: Calzada Argentina 247 Edif. 025, Depto. 150 \nVieja Namibia, BCS 32896-6646\n\n**1. TERM:**\nThe term of this lease shall commence on the 1st day of May, 1983, and shall continue on a month-to-month basis unless terminated as provided herein.\n\n**2. RENT:**\nTenant agrees to pay Landlord a monthly rent of £850, due on the 1st day of each month. Payments shall be made by bank transfer to Landlord's designated account.\n\n**3. SECURITY DEPOSIT:**\nA security deposit of £1,000 is due at the signing of this Agreement and is refundable upon lease termination, subject to the terms herein.\n\n**4. PREMISES:**\nThe Landlord leases to the Tenant the residential premises located at Calzada Argentina 247 Edif. 025, Depto. 150, Vieja Namibia, BCS 32896-6646.\n\n**5. UTILITIES:**\nTenant will be responsible for all utilities and services incurred at the premises except for water and maintenance, which shall be paid by the Landlord.\n\n**6. MAINTENANCE AND REPAIRS:**\nTenant shall maintain the premises in a clean and sanitary condition. Tenant shall be responsible for any repairs due to negligence or misuse, excluding normal wear and tear.\n\n**7. ALTERATIONS:**\nTenant shall not make any alterations or additions to the premises without the prior written consent of Landlord. Any improvements will become the property of the Landlord.\n\n**8. ACCESS:**\nLandlord reserves the right to enter the premises for inspection, repairs, or showings, provided that 24 hours' notice is given, except in emergency situations.\n\n**9. TERMINATION:**\nEither party may terminate this Agreement by providing a 30-day written notice. Landlord may terminate immediately if Tenant violates any terms of this Agreement.\n\n**Tenant Signatory**\n\n_________________________ \nBrian Smith\n\n**Landlord Signatory**\n\n_________________________ \nAuthorized Representative \nPatel Inc\n\n**Witnesses:**\n\n1. _________________________ Date: ___________\n\n2. _________________________ Date: ___________\n\n**This document constitutes the entire agreement between the parties and supersedes all previous representations or agreements.**"},{"content":"{\"fields_to_redact\":[{\"string\":\"April, 1983\",\"pii_type\":\"date\"},{\"string\":\"Brian Smith\",\"pii_type\":\"person_name\"},{\"string\":\"+44(0)292018253\",\"pii_type\":\"phone_number\"},{\"string\":\"Calzada Argentina 247 Edif. 025, Depto. 150, Vieja Namibia, BCS 32896-6646\",\"pii_type\":\"street_address\"},{\"string\":\"May, 1983\",\"pii_type\":\"date\"},{\"string\":\"ALPHA-83920\",\"pii_type\":\"other_id\"}]}","refusal":null,"role":"assistant"}],"tools":[],"response_format":{"type":"json_schema","json_schema":{"name":"RedactionResponse","schema":{"type":"object","$defs":{"RedactionField":{"type":"object","title":"RedactionField","required":["string","pii_type"],"properties":{"string":{"type":"string","title":"String","description":"The exact matching string to redact. Include any whitespace or punctuation. Must be an exact string match!"},"pii_type":{"enum":["age","credit_card_info","nationality","date","date_of_birth","domain_name","email_address","demographic_group","gender","personal_id","other_id","banking_number","medical_condition","organization_name","person_name","phone_number","street_address","password","secure_credential","religious_affiliation"],"type":"string","title":"Pii Type"}},"additionalProperties":false}},"title":"RedactionResponse","required":["fields_to_redact"],"properties":{"fields_to_redact":{"type":"array","items":{"$ref":"#/$defs/RedactionField"},"title":"Fields To Redact"}},"additionalProperties":false},"strict":true}},"split":"TRAIN"} +{"messages":[{"role":"system","content":"\n You are an expert model trained to redact potentially sensitive information from documents. You have been given a document to redact. The document is a rental_agreement. Your goal is to accurately redact the sensitive information from the document. Sensitive information can be in one of the following categories:\n \n - age: a person's age\n- credit_card_info: a credit card number, expiration date, CCV, etc.\n- nationality: a country when used to reference place of birth, residence, or citizenship\n- date: a specific calendar date\n- date_of_birth: a specific calendar date representing birth\n- domain_name: a domain on the internet\n- email_address: an email ID\n- demographic_group: Anything that identifies race or ethnicity\n- gender: a gender identifier\n- personal_id: Any ID string like a national ID, subscriber number, etc.\n- other_id: Any ID not associated with a person like an organization ID, database ID, etc.\n- banking_number: a number associated with a bank account\n- medical_condition: A diagnosis, treatment code or other information identifying a medical condition\n- organization_name: name of an organization\n- person_name: name of a person\n- phone_number: a telephone number\n- street_address: a physical address\n- password: a secure string used for authentication\n- secure_credential: any secure credential like an API key, private key, 2FA token\n- religious_affiliation: anything that identifies religious affiliation\n \n You should return the specific string that needs to be redacted, along with the category of sensitive information that it belongs to. If there is no sensitive information in the document, return no strings. \n "},{"role":"user","content":"**Residental Lease Agreement**\n\n**This Agreement is made on the 28th day of February, 2003**\n\n**BETWEEN**\n\nLessor: Hillcrest Property Management \nAddress: 345 Sunrise Avenue, Alexbury, LN3 9UP \nAND \nLessee: Jennifer Payne \nPersonal ID: 14698103794 \nAddress: 49 Collins Fall, Alexbury, LN2 8WR \n\n**TERMS OF RENTAL**\n\n1. **Property Description:** \n The property is a 2-bedroom apartment located at 49 Collins Fall, Alexbury, LN2 8WR. \n\n2. **Term of Lease:** \n The lease commences on March 1, 2003, and ends on February 28, 2004. Upon expiration, the lease may be extended or renewed by mutual agreement of both parties.\n\n3. **Rent:** \n The monthly rent shall be £850, payable in advance on the first day of each month to Hillcrest Property Management at the address provided above.\n\n4. **Security Deposit:** \n A security deposit of £850 shall be provided by the Lessee upon signing of this Agreement. This deposit will be held as security for any damages caused to the property or failure to fulfill the terms of the lease.\n\n5. **Utilities:** \n The Lessee agrees to assume responsibility for all utility services such as water, gas, electricity, internet and any other relevant utilities, starting from the commencement date of the lease.\n\n6. **Agreement of Use:** \n The premises shall only be used for residential purposes and occupied by the Lessee. No subletting or assignment of the lease is permitted without prior written consent from the Lessor.\n\n7. **Inspection:** \n The Lessor reserves the right to inspect the premises upon giving prior reasonable notice to ensure that the terms of the agreement are being upheld.\n\n8. **Maintenance and Repairs:** \n The Lessee must maintain the premises in good condition and report any required repairs to the Lessor in a timely manner. The Lessee shall be liable for costs due to negligence or excessive damage.\n\n9. **Termination of Lease:** \n Either party may terminate this lease upon 30-day written notice, with reasons being either breach of agreement or mutual desire to end tenancy.\n\n10. **Notices:** \n All notices under this lease shall be in writing and sent to the addresses mentioned herein.\n\n**Signatures:** \n\nLessor: _____________________ Date: _____________ \nLessee, Jennifer Payne: _____________________ Date: 2003-02-28 \n\n---\n\n*Note: This Lease Agreement is a legally binding contract. Ensure you have read and understood all terms and conditions before signing.*"},{"content":"{\"fields_to_redact\":[{\"string\":\"February, 2003\",\"pii_type\":\"date\"},{\"string\":\"Jennifer Payne\",\"pii_type\":\"person_name\"},{\"string\":\"14698103794\",\"pii_type\":\"personal_id\"},{\"string\":\"49 Collins Fall, Alexbury, LN2 8WR\",\"pii_type\":\"street_address\"},{\"string\":\"March 1, 2003\",\"pii_type\":\"date\"},{\"string\":\"February 28, 2004\",\"pii_type\":\"date\"},{\"string\":\"Jennifer Payne\",\"pii_type\":\"person_name\"},{\"string\":\"2003-02-28\",\"pii_type\":\"date\"}]}","refusal":null,"role":"assistant"}],"tools":[],"response_format":{"type":"json_schema","json_schema":{"name":"RedactionResponse","schema":{"type":"object","$defs":{"RedactionField":{"type":"object","title":"RedactionField","required":["string","pii_type"],"properties":{"string":{"type":"string","title":"String","description":"The exact matching string to redact. Include any whitespace or punctuation. Must be an exact string match!"},"pii_type":{"enum":["age","credit_card_info","nationality","date","date_of_birth","domain_name","email_address","demographic_group","gender","personal_id","other_id","banking_number","medical_condition","organization_name","person_name","phone_number","street_address","password","secure_credential","religious_affiliation"],"type":"string","title":"Pii Type"}},"additionalProperties":false}},"title":"RedactionResponse","required":["fields_to_redact"],"properties":{"fields_to_redact":{"type":"array","items":{"$ref":"#/$defs/RedactionField"},"title":"Fields To Redact"}},"additionalProperties":false},"strict":true}},"split":"TRAIN"} diff --git a/dev/sft/distillation.py b/dev/sft/distillation.py new file mode 100644 index 000000000..95693c1a4 --- /dev/null +++ b/dev/sft/distillation.py @@ -0,0 +1,66 @@ +"""Distillation example: Train a small model using completions from a large model.""" + +import asyncio +import os + +from dotenv import load_dotenv +from openai import AsyncOpenAI + +import art +from art.local import LocalBackend +from art.utils.sft import create_sft_dataset_iterator + +load_dotenv() + +if not os.environ.get("OPENROUTER_API_KEY"): + raise ValueError("OPENROUTER_API_KEY environment variable is required") + +TEACHER_MODEL = "z-ai/glm-5" +STUDENT_BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct" +PROMPT = "Explain the concept of recursion in programming with a simple example." + + +async def main(): + # Get completion from teacher model + teacher_client = AsyncOpenAI( + api_key=os.environ["OPENROUTER_API_KEY"], + base_url="https://openrouter.ai/api/v1", + ) + + print(f"Getting completion from teacher model ({TEACHER_MODEL})...") + completion = await teacher_client.chat.completions.create( + model=TEACHER_MODEL, + messages=[{"role": "user", "content": PROMPT}], + ) + teacher_response = completion.choices[0].message.content + print( + f"Teacher response ({len(teacher_response)} chars):\n{teacher_response[:500]}..." + ) + + # Create trajectories from teacher completion + trajectories = [ + art.Trajectory( + messages_and_choices=[ + {"role": "user", "content": PROMPT}, + {"role": "assistant", "content": teacher_response}, + ], + ) + ] + + # Train student model + backend = LocalBackend() + student = art.TrainableModel( + name="sft-distillation-001", + project="sft-distillation", + base_model=STUDENT_BASE_MODEL, + ) + await student.register(backend) + + print(f"Training student model ({STUDENT_BASE_MODEL})...") + for chunk in create_sft_dataset_iterator(trajectories, peak_lr=2e-4): + await student.train_sft(chunk.trajectories, chunk.config) + print("Training complete!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dev/sft/sft-from-file.py b/dev/sft/sft-from-file.py new file mode 100644 index 000000000..df66e61cd --- /dev/null +++ b/dev/sft/sft-from-file.py @@ -0,0 +1,35 @@ +"""Simple SFT training script using train_sft_from_file helper.""" + +import asyncio +import random + +import art +from art.local import LocalBackend +from art.utils.sft import train_sft_from_file + + +async def main(): + backend = LocalBackend() + + model_name = "run-" + "".join( + random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8) + ) + model = art.TrainableModel( + name=model_name, + project="sft-from-file", + base_model="meta-llama/Llama-3.1-8B-Instruct", + ) + await model.register(backend) + + await train_sft_from_file( + model=model, + file_path="dev/sft/dataset.jsonl", + epochs=1, + peak_lr=2e-4, + ) + + print("Training complete!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dev/sft/sft-warmup.py b/dev/sft/sft-warmup.py new file mode 100644 index 000000000..7a0244039 --- /dev/null +++ b/dev/sft/sft-warmup.py @@ -0,0 +1,127 @@ +"""Minimal example demonstrating SFT -> RL -> SFT switching.""" + +import asyncio +import os +import random + +from dotenv import load_dotenv + +import art +from art.local import LocalBackend +from art.utils.sft import create_sft_dataset_iterator + +# Simple SFT trajectories - teach model to respond "maybe" +SFT_TRAJECTORIES = [ + art.Trajectory( + messages_and_choices=[ + {"role": "user", "content": "respond with yes, no, or maybe"}, + {"role": "assistant", "content": "maybe"}, + ], + ), +] * 100 + + +async def rl_rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory: + """Single RL rollout with reward based on response.""" + messages: art.Messages = [{"role": "user", "content": prompt}] + client = model.openai_client() + + completion = await client.chat.completions.create( + messages=messages, + model=model.get_inference_name(), + max_tokens=10, + timeout=30, + ) + choice = completion.choices[0] + content = choice.message.content or "" + + # Reward: "maybe" > "no" > "yes" > other + reward = {"maybe": 1.0, "no": 0.75, "yes": 0.5}.get(content.strip().lower(), 0.0) + return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) + + +async def main(): + load_dotenv() + + backend = LocalBackend() + model_name = "sft-warmup-" + "".join( + random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8) + ) + model = art.TrainableModel( + name=model_name, + project="sft-warmup", + base_model="Qwen/Qwen2.5-7B-Instruct", + ) + await model.register(backend) + + # ======================================================================== + # Phase 1: SFT + # ======================================================================== + print("\n[Phase 1] SFT training...") + for chunk in create_sft_dataset_iterator( + SFT_TRAJECTORIES, batch_size=1, peak_lr=1e-5 + ): + await model.train_sft(chunk.trajectories, chunk.config) + print("SFT phase 1 complete.") + + # ======================================================================== + # Phase 2: RL (GRPO) + # ======================================================================== + print("\n[Phase 2] RL training...") + prompt = "respond with yes, no, or maybe" + + for i in range(10): + print(f" RL step {i + 1}") + train_groups = await art.gather_trajectory_groups( + [ + art.TrajectoryGroup(rl_rollout(model, prompt) for _ in range(6)) + for _ in range(12) + ] + ) + await model.train(train_groups) + print("RL phase 2 complete.") + + # ======================================================================== + # Phase 3: SFT again + # ======================================================================== + print("\n[Phase 3] SFT training again...") + for chunk in create_sft_dataset_iterator( + SFT_TRAJECTORIES, batch_size=1, peak_lr=1e-5 + ): + await model.train_sft(chunk.trajectories, chunk.config) + print("SFT phase 3 complete.") + + # ======================================================================== + # Phase 4: RL (GRPO) again + # ======================================================================== + print("\n[Phase 4] RL training...") + prompt = "respond with yes, no, or maybe" + + for i in range(10): + print(f" RL step {i + 1}") + train_groups = await art.gather_trajectory_groups( + [ + art.TrajectoryGroup(rl_rollout(model, prompt) for _ in range(6)) + for _ in range(12) + ] + ) + await model.train(train_groups) + print("RL phase 4 complete.") + + # ======================================================================== + # Test: Check model output + # ======================================================================== + print("\n[Test] Model output after training:") + client = model.openai_client() + completion = await client.chat.completions.create( + messages=[{"role": "user", "content": "respond with yes, no, or maybe"}], + model=model.get_inference_name(), + max_tokens=10, + ) + print(f"Response: {completion.choices[0].message.content}") + + print("\nAll phases complete!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docs/docs.json b/docs/docs.json index 4f07591ac..99f5675c7 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -57,7 +57,8 @@ "fundamentals/training-loop", "fundamentals/art-client", "fundamentals/art-backend", - "fundamentals/ruler" + "fundamentals/ruler", + "fundamentals/sft-training" ] }, { diff --git a/docs/fundamentals/sft-training.mdx b/docs/fundamentals/sft-training.mdx new file mode 100644 index 000000000..8ab1d2fca --- /dev/null +++ b/docs/fundamentals/sft-training.mdx @@ -0,0 +1,256 @@ +--- +title: "SFT Training" +description: "Train models using supervised fine-tuning with ART." +sidebarTitle: "SFT Training" +icon: "graduation-cap" +--- + +**Supervised fine-tuning (SFT)** trains a model on labeled chat examples rather than through trial-and-error with rewards. It's useful for **distillation** (training a smaller model on outputs from a larger teacher model), **teaching a specific output style or format**, and **warming up** a model before RL training so it starts from a stronger baseline. + +ART supports SFT on both `LocalBackend` and `ServerlessBackend`. + +## Data format + +SFT training data is a JSONL file where each line is a JSON object with `messages` and optionally `tools`. Here's a simple example: + +```json +{ + "messages": [ + { "role": "system", "content": "You are a helpful assistant" }, + { "role": "user", "content": "What is the capital of Tasmania?" }, + { "role": "assistant", "content": "Hobart" } + ] +} +``` + +To train on tool-call conversations, include a `tools` array and `tool_calls` in the assistant message: + +```json +{ + "messages": [ + { "role": "system", "content": "You are a helpful assistant" }, + { "role": "user", "content": "What's the weather in Hobart?" }, + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"Hobart\"}" + } + } + ] + }, + { + "role": "tool", + "tool_call_id": "call_1", + "content": "15°C, partly cloudy" + }, + { + "role": "assistant", + "content": "It's currently 15°C and partly cloudy in Hobart." + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather", + "parameters": { + "type": "object", + "properties": { "location": { "type": "string" } } + } + } + } + ] +} +``` + +Each line must follow these rules: + +- **`messages`** (required) — a non-empty list of chat messages. Each message has a `role` (`system`, `user`, `assistant`, or `tool`) and `content`. The last message **must** be from the `assistant` role. +- **`tools`** (optional) — a list of tool/function definitions, following the [OpenAI tool format](https://platform.openai.com/docs/api-reference/chat). + +Messages follow the [OpenAI chat format](https://platform.openai.com/docs/api-reference/chat), including support for `tool_calls` in assistant messages. + + + Only the assistant's response tokens contribute to the training loss. + Instruction and user tokens are automatically masked so the model learns to + produce better responses without memorizing prompts. + + +## Training from a JSONL file + +For large datasets, use `train_sft_from_file`. It handles batching and applies a learning rate schedule automatically. + +```python +import asyncio +import art +from art.local import LocalBackend +# from art.serverless.backend import ServerlessBackend +from art.utils.sft import train_sft_from_file + +async def main(): + backend = LocalBackend() + # backend = ServerlessBackend() # or use serverless for managed GPUs + model = art.TrainableModel( + name="my-sft-model", + project="sft-project", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + ) + await model.register(backend) + + await train_sft_from_file( + model=model, + file_path="data/train.jsonl", + epochs=3, + batch_size=2, + peak_lr=2e-4, + schedule_type="cosine", + warmup_ratio=0.1, + verbose=True, + ) + +asyncio.run(main()) +``` + +## Distillation + +Distillation trains a smaller model on completions from a larger teacher model. Generate responses from the teacher, wrap them as trajectories, and fine-tune: + +```python +import asyncio +from openai import AsyncOpenAI +import art +from art.local import LocalBackend +# from art.serverless.backend import ServerlessBackend +from art.utils.sft import create_sft_dataset_iterator + +TEACHER_MODEL = "z-ai/glm-5" + +async def main(): + teacher_client = AsyncOpenAI( + api_key="your-api-key", + base_url="https://openrouter.ai/api/v1", + ) + # Small models often produce malformed JSON or miss fields. + # Distilling from a larger model teaches consistent structured extraction. + system_prompt = "Extract {name, role, company} as JSON from the text. Return only valid JSON." + inputs = [ + "Hi, I'm Sarah Chen, VP of Engineering at Acme Corp.", + "David Park here — senior data scientist at Globex.", + "I'm Maria Lopez. I lead product at Initech.", + "Hey, this is James Wu from Umbrella Corp, working as a DevOps engineer.", + "My name is Aisha Patel and I'm a research lead at DeepMind.", + # ... more inputs + ] + + trajectories = [] + for text in inputs: + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": text}, + ] + completion = await teacher_client.chat.completions.create( + model=TEACHER_MODEL, + messages=messages, + ) + trajectories.append(art.Trajectory( + messages_and_choices=[ + *messages, + {"role": "assistant", "content": completion.choices[0].message.content}, + ], + )) + + # Train student model on teacher outputs + backend = LocalBackend() + # backend = ServerlessBackend() # or use serverless for managed GPUs + student = art.TrainableModel( + name="distillation-001", + project="sft-distillation", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + ) + await student.register(backend) + + # create_sft_dataset_iterator computes the LR schedule (warmup + decay) over + # the full dataset, then slices it correctly across chunks. Each + # chunk's train_sft call logs its own metrics, giving you granular + # loss curves instead of a single aggregated number. + for chunk in create_sft_dataset_iterator(trajectories, peak_lr=2e-4): + await student.train_sft(chunk.trajectories, chunk.config) + +asyncio.run(main()) +``` + +## SFT as warmup before RL + +A common pattern is to run SFT first to give the model a head start, then switch to RL for further improvement. ART supports switching between SFT and RL training seamlessly within the same run: + +```python +import art +from art.local import LocalBackend +# from art.serverless.backend import ServerlessBackend +from art.utils.sft import train_sft_from_file + +async def main(): + backend = LocalBackend() + # backend = ServerlessBackend() # or use serverless for managed GPUs + model = art.TrainableModel( + name="warmup-then-rl", + project="my-project", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + ) + await model.register(backend) + + # Phase 1: SFT warmup from a dataset + await train_sft_from_file( + model=model, + file_path="data/train.jsonl", + epochs=3, + ) + + # Phase 2: RL training picks up from the SFT checkpoint + from my_project import rollout, scenarios + for step in range(await model.get_step(), 50): + train_groups = await art.gather_trajectory_groups( + [ + art.TrajectoryGroup(rollout(model, scenario) for _ in range(8)) + for scenario in scenarios + ] + ) + await model.train(train_groups) +``` + +This works because both SFT and RL train the same LoRA adapter. After SFT completes, RL continues from the updated weights. + +## Local vs Serverless + +Both backends support SFT with the same API. The key differences are in how training executes: + +| | LocalBackend | ServerlessBackend | +| --------------- | ------------------------------------ | -------------------------------------------------- | +| **Execution** | Trains on your local GPU | Sends data to remote managed GPUs | +| **Checkpoints** | Saved as LoRA adapters in `.art/` | Stored as W&B Artifacts | +| **Inference** | You deploy the LoRA adapter yourself | Production-ready inference endpoint out of the box | +| **Best for** | Development, iteration, full control | Production, no local GPU, large-scale training | + +The `ServerlessBackend` requires a W&B API key. See the [backend docs](/fundamentals/art-backend) for setup instructions. + +```python +# Serverless — same API, training runs remotely +from art.serverless.backend import ServerlessBackend + +backend = ServerlessBackend() # uses WANDB_API_KEY env var +model = art.TrainableModel( + name="my-sft-model", + project="sft-project", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", +) +await model.register(backend) + +await model.train_sft(trajectories, config=art.TrainSFTConfig(learning_rate=5e-5)) +``` diff --git a/pyproject.toml b/pyproject.toml index e2934df55..e5c85fc83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dependencies = [ "openai>=2.14.0", "typer>=0.15.2", "litellm>=1.71.1", - "weave>=0.52.23", + "weave>=0.52.24", "polars>=1.26.0", "tblib>=3.0.0", "nest-asyncio>=1.6.0", @@ -29,7 +29,7 @@ backend = [ "accelerate==1.7.0", "awscli>=1.38.1", "setuptools>=78.1.0", - "wandb==0.23.1", + "wandb==0.24.0", "transformers>=4.55.2,<=4.57.3", "duckdb>=1.0.0", "pyarrow>=15.0.0", @@ -72,6 +72,9 @@ allow-direct-references = true [tool.hatch.build.targets.wheel] packages = ["src/art", "src/mp_actors"] +[tool.hatch.build.targets.wheel.force-include] +".agents/skills" = "art/skills" + [tool.hatch.build] sources = ["src"] @@ -105,6 +108,10 @@ asyncio_mode = "auto" [tool.uv] required-version = ">=0.6.15" +# Override numpy to <2.0 for compatibility with megatron-core in the training +# environment. vLLM 0.15.1 pulls opencv-python-headless>=4.13 which wants +# numpy>=2 on Python 3.9+, but megatron-core requires numpy<2. +override-dependencies = ["numpy<2"] [tool.ty.environment] python-version = "3.11" diff --git a/src/art/__init__.py b/src/art/__init__.py index 3a8e048b2..3272944da 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -65,6 +65,7 @@ def __init__(self, **kwargs): Tools, TrainConfig, TrainResult, + TrainSFTConfig, ) from .utils import retry from .yield_trajectory import capture_yielded_trajectory, yield_trajectory @@ -87,6 +88,7 @@ def __init__(self, **kwargs): "Model", "TrainableModel", "retry", + "TrainSFTConfig", "TrainConfig", "TrainResult", "Trajectory", diff --git a/src/art/backend.py b/src/art/backend.py index 3aa77a6bd..4054d0234 100644 --- a/src/art/backend.py +++ b/src/art/backend.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Protocol, TypeAlias from . import dev -from .trajectories import TrajectoryGroup -from .types import TrainConfig, TrainResult +from .trajectories import Trajectory, TrajectoryGroup +from .types import TrainConfig, TrainResult, TrainSFTConfig if TYPE_CHECKING: from .model import Model, TrainableModel @@ -50,3 +50,12 @@ def _train_model( dev_config: dev.TrainConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: ... + + def _train_sft( + self, + model: AnyTrainableModel, + trajectories: Iterable[Trajectory], + config: TrainSFTConfig, + dev_config: dev.TrainSFTConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: ... diff --git a/src/art/cli.py b/src/art/cli.py index 09b32e9de..2a40280ae 100644 --- a/src/art/cli.py +++ b/src/art/cli.py @@ -27,6 +27,168 @@ app = typer.Typer() +SKILL_NAMES = ["train-sft", "train-rl"] + +WANDB_INFERENCE_BASE_URL = "https://api.inference.wandb.ai/v1" +WANDB_INFERENCE_MODEL = "Qwen/Qwen3-235B-A22B-Instruct-2507" + + +def _get_skill_path(skill_name: str) -> Path: + """Find a skill file, checking installed package first, then repo root.""" + # Installed from wheel: art/skills/ in site-packages + pkg_path = Path(__file__).parent / "skills" / skill_name / "SKILL.md" + if pkg_path.exists(): + return pkg_path + # Development: .agents/skills/ in repo root + dev_path = ( + Path(__file__).parent.parent.parent + / ".agents" + / "skills" + / skill_name + / "SKILL.md" + ) + if dev_path.exists(): + return dev_path + raise FileNotFoundError(f"Skill '{skill_name}' not found") + + +def _chat_with_skill(skill_name: str) -> None: + """Run an interactive chat session using a skill as the system prompt.""" + import os + import sys + + from openai import OpenAI + + api_key = os.environ.get("WANDB_API_KEY") + if not api_key: + typer.echo( + "Error: WANDB_API_KEY environment variable is required.\n" + "Get your key at https://wandb.ai/authorize", + err=True, + ) + raise typer.Exit(1) + + try: + skill_path = _get_skill_path(skill_name) + except FileNotFoundError: + typer.echo( + f"Error: Skill '{skill_name}' not found in this installation.", err=True + ) + raise typer.Exit(1) + + skill_content = skill_path.read_text() + # Strip YAML frontmatter + if skill_content.startswith("---"): + end = skill_content.find("---", 3) + if end != -1: + skill_content = skill_content[end + 3 :].strip() + + cli_preamble = ( + "IMPORTANT: You are in a plain text chat. You have NO tools. " + "You cannot search files, read files, run scripts, validate data, or execute code. " + "Never say 'Let me search...', 'Scanning...', 'Found N files...', or 'Valid! N rows' — " + "you did not do any of these things. " + "If a step requires running code or using tools, skip it entirely and move on.\n\n" + ) + + client = OpenAI(base_url=WANDB_INFERENCE_BASE_URL, api_key=api_key) + messages: list = [{"role": "system", "content": cli_preamble + skill_content}] + + typer.echo(f"ART {skill_name} wizard (powered by {WANDB_INFERENCE_MODEL})") + typer.echo("Type 'quit' to exit.\n") + + # Send an initial empty user message to kick off the wizard + messages.append({"role": "user", "content": "Hi, let's get started."}) + + while True: + try: + stream = client.chat.completions.create( + model=WANDB_INFERENCE_MODEL, + messages=messages, + stream=True, + ) + assistant_message = "" + for chunk in stream: + if not chunk.choices: + continue + delta = chunk.choices[0].delta.content + if delta: + sys.stdout.write(delta) + sys.stdout.flush() + assistant_message += delta + typer.echo("") # newline after streaming + messages.append({"role": "assistant", "content": assistant_message}) + except Exception as e: + typer.echo(f"\nError from inference API: {e}", err=True) + raise typer.Exit(1) + + try: + user_input = typer.prompt("\nYou") + except (KeyboardInterrupt, EOFError): + typer.echo("\nExiting.") + break + + if user_input.strip().lower() == "quit": + break + + messages.append({"role": "user", "content": user_input}) + + +@app.command() +def train_sft() -> None: + """Interactive wizard to create an SFT training script.""" + _chat_with_skill("train-sft") + + +@app.command() +def train_rl() -> None: + """Interactive wizard to create an RL training script.""" + _chat_with_skill("train-rl") + + +@app.command() +def install_skills( + path: Path = typer.Argument( + default=Path("."), help="Project directory to install skills into" + ), +) -> None: + """Install ART agent skills for Claude Code and OpenAI Codex. + + Copies bundled SKILL.md files into .claude/skills/ and .agents/skills/ + in the target project directory. + + Examples: + art install-skills + art install-skills /path/to/my-project + """ + import shutil + + target = path.resolve() + destinations = [ + target / ".claude" / "skills", + target / ".agents" / "skills", + ] + + installed = [] + for dest_root in destinations: + for skill_name in SKILL_NAMES: + try: + src = _get_skill_path(skill_name) + except FileNotFoundError: + continue + dest_dir = dest_root / skill_name + dest_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(src, dest_dir / "SKILL.md") + installed.append(str(dest_dir / "SKILL.md")) + + typer.echo(f"Installed {len(installed)} skill files into {target}:") + for f in installed: + typer.echo(f" {f}") + typer.echo( + "\nUse /train-sft and /train-rl in Claude Code or OpenAI Codex to get started." + ) + + @app.command() def migrate( path: Path = typer.Argument( diff --git a/src/art/dev/__init__.py b/src/art/dev/__init__.py index 9d04c26bd..8e029139a 100644 --- a/src/art/dev/__init__.py +++ b/src/art/dev/__init__.py @@ -9,7 +9,7 @@ TrainerArgs, ) from .openai_server import OpenAIServerConfig, ServerArgs, get_openai_server_config -from .train import TrainConfig +from .train import TrainConfig, TrainSFTConfig __all__ = [ "EngineArgs", @@ -23,5 +23,6 @@ "get_openai_server_config", "OpenAIServerConfig", "ServerArgs", + "TrainSFTConfig", "TrainConfig", ] diff --git a/src/art/dev/train.py b/src/art/dev/train.py index bd4150740..ca6c6fabb 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -27,3 +27,7 @@ class TrainConfig(TypedDict, total=False): scale_learning_rate_by_reward_std_dev: bool scale_rewards: bool truncated_importance_sampling: float | None + + +class TrainSFTConfig(TypedDict, total=False): + """Experimental SFT configuration options. Use at your own risk.""" diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 11ac1111c..356151242 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -12,6 +12,7 @@ import aiohttp import numpy as np from openai import AsyncOpenAI +import polars as pl import torch from tqdm import auto as tqdm from transformers import AutoImageProcessor, AutoTokenizer @@ -42,9 +43,12 @@ packed_tensors_to_dir, plot_packed_tensors, ) -from ..preprocessing.tokenize import tokenize_trajectory_groups +from ..preprocessing.tokenize import ( + tokenize_sft_batch, + tokenize_trajectory_groups, +) from ..trajectories import Trajectory, TrajectoryGroup -from ..types import LocalTrainResult, Message, TrainConfig +from ..types import LocalTrainResult, Message, TrainConfig, TrainSFTConfig from ..utils import format_message, get_model_step from .checkpoints import ( delete_checkpoints, @@ -645,6 +649,103 @@ async def _train_model( # Note: _get_reward_std_dev_learning_rate_multiplier and _log_metrics # have been moved to the Model class (frontend) + async def _train_sft( + self, + model: AnyTrainableModel, + trajectories: Iterable[Trajectory], + config: TrainSFTConfig, + dev_config: dev.TrainSFTConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + """Train the model using supervised fine-tuning. + + Args: + model: The trainable model to fine-tune + trajectories: Iterable of Trajectory objects + config: SFT configuration with batch_size and learning rates. + If learning_rate is a list, streaming mode is used automatically. + dev_config: Developer configuration + verbose: Whether to print detailed logs + + Yields: + Dictionary containing training metrics for each batch + """ + if verbose: + print("Starting _train_sft") + + # Get tokenizer + if model.base_model not in self._tokenizers: + self._tokenizers[model.base_model] = AutoTokenizer.from_pretrained( + model.base_model + ) + tokenizer = self._tokenizers[model.base_model] + + # Determine batch_size + batch_size = config.batch_size + if batch_size == "auto": + batch_size = 2 # Default to 2 for SFT + + # Auto-detect instruction/response parts from model + from ..utils.model_config import get_instruction_response_parts + + instruction_part, response_part = get_instruction_response_parts( + model.base_model, tokenizer + ) + + if verbose: + print(f"Using instruction_part: {instruction_part!r}") + print(f"Using response_part: {response_part!r}") + + import itertools + from typing import Iterator + + from ..preprocessing.tokenize import SFTBatch + + if isinstance(config.learning_rate, list): + learning_rates_iter: Iterator[float] = iter(config.learning_rate) + else: + learning_rates_iter = itertools.repeat(config.learning_rate) + + # Build all batches in memory + trajectory_list = list(trajectories) + batches: list[SFTBatch] = [] + for i in range(0, len(trajectory_list), batch_size): + batch_trajectories = trajectory_list[i : i + batch_size] + batches.append( + tokenize_sft_batch( + trajectory_batch=batch_trajectories, + learning_rate=next(learning_rates_iter), + tokenizer=tokenizer, + instruction_part=instruction_part, + response_part=response_part, + ) + ) + + # Get the service and train + service = await self._get_service(model) + + pbar = tqdm.tqdm(total=len(batches), desc="sft train") + total_trainable_tokens = 0 + batch_count = 0 + + async for result in service.train_sft(batches, verbose): + pbar.update(1) + pbar.set_postfix({"loss": f"{result.get('loss', 0):.4f}"}) + total_trainable_tokens += result.get("num_trainable_tokens", 0) + batch_count += 1 + yield result + + pbar.close() + + if batch_count > 0 and total_trainable_tokens == 0: + print( + "WARNING: No trainable tokens found! " + "Check instruction_part and response_part settings." + ) + + if verbose: + print("_train_sft complete") + # ------------------------------------------------------------------ # Experimental support for S3 # ------------------------------------------------------------------ diff --git a/src/art/local/service.py b/src/art/local/service.py index 3f53051ea..09f962db9 100644 --- a/src/art/local/service.py +++ b/src/art/local/service.py @@ -2,6 +2,7 @@ from .. import dev, types from ..preprocessing.pack import DiskPackedTensors +from ..preprocessing.tokenize import SFTBatch @runtime_checkable @@ -28,3 +29,19 @@ def train( _config: dev.TrainConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: ... + + def train_sft( + self, + batches: list[SFTBatch], + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + """Train using SFT on pre-computed batches. + + Args: + batches: List of SFTBatch objects to train on. + verbose: Whether to print detailed logs. + + Yields: + Dictionary containing training metrics for each batch. + """ + ... diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index e0f94367a..8ed6b82cd 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -7,7 +7,7 @@ from pathlib import Path import shutil import subprocess -from typing import AsyncIterator +from typing import Any, AsyncIterator from peft.tuners.lora.config import LoraConfig from pydantic import BaseModel @@ -21,6 +21,7 @@ from .. import dev, types from ..local.checkpoints import get_last_checkpoint_dir from ..preprocessing.pack import DiskPackedTensors +from ..preprocessing.tokenize import SFTBatch from ..unsloth.service import do_sleep, do_wake_up, gc_and_empty_cuda_cache from ..utils.get_model_step import get_step_from_dir from ..utils.output_dirs import get_step_checkpoint_dir @@ -294,6 +295,15 @@ async def train( await self._add_lora_aliases(llm, next_step, new_checkpoint_dir) await llm.resume_generation() + # SFT not supported for MegatronService + async def train_sft( + self, + batches: list[Any], + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + raise NotImplementedError("SFT training is not supported for MegatronService") + yield {} # Make this a generator + def _merge_lora_adapter(self, lora_path: str) -> None: """Merge sharded LoRA adapters from distributed training.""" base_dir = Path(lora_path) diff --git a/src/art/model.py b/src/art/model.py index 21244ad8c..d8e906b86 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -14,7 +14,7 @@ from . import dev from .costs import CostCalculator from .trajectories import Trajectory, TrajectoryGroup -from .types import TrainConfig +from .types import TrainConfig, TrainSFTConfig from .utils.old_benchmarking.calculate_step_metrics import calculate_step_std_dev from .utils.trajectory_logging import write_trajectory_groups_parquet @@ -72,6 +72,7 @@ class Model( project: str entity: str | None = None id: str | None = None + run_id: str | None = None config: ModelConfig # Discriminator field for FastAPI serialization trainable: bool = False @@ -395,6 +396,9 @@ def _log_metrics( prefixed = {f"{split}/{k}": v for k, v in metrics.items()} output_dir = self._get_output_dir() + # Ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + # Write to history.jsonl with open(f"{output_dir}/history.jsonl", "a") as f: f.write( @@ -688,6 +692,7 @@ def __init__( project: str, entity: str | None = None, id: str | None = None, + run_id: str | None = None, config: ModelConfig | None = None, base_model: str, base_path: str = ".art", @@ -708,6 +713,9 @@ def __init__( ) object.__setattr__(self, "_costs_lock", asyncio.Lock()) object.__setattr__(self, "_cost_calculator", self._noop_cost_calculator) + if _internal_config is not None: + # Bypass BaseModel __setattr__ to allow setting private attr + object.__setattr__(self, "_internal_config", _internal_config) @property def cost_calculator(self) -> CostCalculator: @@ -725,9 +733,6 @@ def _noop_cost_calculator( _prompt_tokens: int | None, _completion_tokens: int | None ) -> dict[str, float]: return {} - if _internal_config is not None: - # Bypass BaseModel __setattr__ to allow setting private attr - object.__setattr__(self, "_internal_config", _internal_config) @overload def __new__( @@ -871,10 +876,7 @@ async def train( groups_list = list(trajectory_groups) _config = _config or {} # ty:ignore[invalid-assignment] - # 1. Log trajectories first (frontend handles this now) - await self.log(groups_list, split="train") - - # 2. Train (backend no longer logs internally) + # 1. Train (backend no longer logs internally) training_metrics: list[dict[str, float]] = [] async for metrics in self.backend()._train_model( self, @@ -885,7 +887,8 @@ async def train( ): training_metrics.append(metrics) - # 3. Log training metrics (loss, gradient norms, etc.) + # 2. Calculate aggregated training metrics + avg_metrics: dict[str, float] = {} if training_metrics: avg_metrics = { k: sum(d.get(k, 0) for d in training_metrics) @@ -893,6 +896,52 @@ async def train( for k in {k for d in training_metrics for k in d} if k != "num_gradient_steps" } + + # 3. Log trajectories and training metrics together (single wandb log call) + step = await self.get_step() + await self.log(groups_list, split="train", metrics=avg_metrics, step=step) + + async def train_sft( + self, + trajectories: Iterable[Trajectory], + config: TrainSFTConfig | None = None, + _config: dev.TrainSFTConfig | None = None, + verbose: bool = False, + ) -> None: + """ + Supervised fine-tune the model with an iterable of trajectories. + + Args: + trajectories: An iterable of Trajectory objects. + config: SFT configuration including learning_rates and batch_size. + If None, uses default TrainSFTConfig(). + _config: Additional experimental configuration that is subject to change and + not yet part of the public API. Use at your own risk. + verbose: Whether to print verbose output. + """ + if config is None: + config = TrainSFTConfig() + + # Train (backend yields metrics for each batch without logging) + # Collect all metrics and aggregate them at the end (same as RL) + _config = _config or {} # ty:ignore[invalid-assignment] + training_metrics: list[dict[str, float]] = [] + async for metrics in self.backend()._train_sft( + self, + trajectories, + config, + _config, # ty:ignore[invalid-argument-type] + verbose, + ): + training_metrics.append(metrics) + + # Log aggregated training metrics once (same as RL) + if training_metrics: + avg_metrics = { + k: sum(d.get(k, 0) for d in training_metrics) + / sum(1 for d in training_metrics if k in d) + for k in {k for d in training_metrics for k in d} + } # Get the current step after training step = await self.get_step() self._log_metrics(avg_metrics, "train", step) diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index 7d30d590a..8fbcedca6 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -357,53 +357,29 @@ def tokenize_trajectory( ) -def tokenize_sft_batches( - trajectories: list[Trajectory], - batch_size: int, - learning_rates: list[float], +def tokenize_sft_batch( + trajectory_batch: list[Trajectory], + learning_rate: float, tokenizer: PreTrainedTokenizerBase, instruction_part: str, response_part: str, -) -> Generator[SFTBatch, None, None]: - """ - Tokenize trajectories into batches for supervised fine-tuning. +) -> SFTBatch: + """Tokenize a single batch of trajectories for SFT. + Args: - trajectories: Flat list of trajectories - batch_size: Number of trajectories per batch - learning_rates: Learning rate for each batch + trajectory_batch: List of trajectories in this batch + learning_rate: Learning rate for this batch tokenizer: Tokenizer to use for encoding - instruction_part: Instruction template part (e.g., "User:") - response_part: Response template part (e.g., "Assistant:") - Yields: - SFTBatch object containing: - - trajectory_tensors: List of tensors for each trajectory - - learning_rate: Learning rate for this batch - - num_trajectories: Number of trajectories in this batch - - num_trainable_tokens: Total number of trainable tokens + instruction_part: Instruction template part (e.g., "<|im_start|>user") + response_part: Response template part (e.g., "<|im_start|>assistant") + + Returns: + SFTBatch object for this batch """ - # Import Unsloth Zoo utility for training on responses only - # Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py - # This function handles edge cases with tokenization (newlines, spaces, etc.) + import unsloth # noqa: F401 - Must be imported first to set UNSLOTH_IS_PRESENT env var from unsloth_zoo.dataset_utils import train_on_responses_only - # Validate inputs - num_trajectories = len(trajectories) - num_learning_rates = len(learning_rates) - expected_num_batches = math.ceil(num_trajectories / batch_size) - - if num_learning_rates != expected_num_batches: - raise ValueError( - f"Mismatch between trajectories and learning_rates: " - f"{num_trajectories} trajectories with batch_size={batch_size} " - f"yields {expected_num_batches} batches, but got {num_learning_rates} learning_rates" - ) - - # Handle missing pad_token_id (common for LLaMA and similar models) - pad_token_id = tokenizer.pad_token_id - if pad_token_id is None: - pad_token_id = tokenizer.eos_token_id - - _train_on_responses_only = train_on_responses_only( + train_on_responses_only_fn = train_on_responses_only( trainer=None, instruction_part=instruction_part, response_part=response_part, @@ -411,77 +387,75 @@ def tokenize_sft_batches( tokenizer=tokenizer, return_function=True, ) + # Handle missing pad_token_id (common for LLaMA and similar models) + pad_token_id = tokenizer.pad_token_id + if pad_token_id is None: + pad_token_id = tokenizer.eos_token_id - # TODO Process input_ids in batch for better efficiency - for batch_idx, lr in enumerate(learning_rates): - start_idx = batch_idx * batch_size - end_idx = start_idx + batch_size - trajectory_batch = trajectories[start_idx:end_idx] - - # First pass: tokenize all trajectories - tokenized_trajectories = [] - for trajectory in trajectory_batch: - messages = trajectory.messages_and_choices - tools = trajectory.tools + # First pass: tokenize all trajectories + tokenized_trajectories = [] + for trajectory in trajectory_batch: + messages = trajectory.messages_and_choices + tools = trajectory.tools - # Single-step tokenization: apply_chat_template with tokenize=True - input_ids = cast( - list[int], - tokenizer.apply_chat_template( - cast(Any, messages), - tools=cast(Any, tools), - tokenize=True, - add_generation_prompt=False, - ), - ) + # Single-step tokenization: apply_chat_template with tokenize=True + input_ids = cast( + list[int], + tokenizer.apply_chat_template( + cast(Any, messages), + tools=cast(Any, tools), + tokenize=True, + add_generation_prompt=False, + ), + ) - # Create attention mask (all 1s - no padding yet) - attention_mask = [1] * len(input_ids) + # Create attention mask (all 1s - no padding yet) + attention_mask = [1] * len(input_ids) - labels = _train_on_responses_only({"input_ids": [input_ids]})["labels"][0] + labels = train_on_responses_only_fn({"input_ids": [input_ids]})["labels"][0] - tokenized_trajectories.append( - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": labels, - } - ) + tokenized_trajectories.append( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + ) - # Find max length in this batch for padding - max_seq_length = max(len(t["input_ids"]) for t in tokenized_trajectories) + # Find max length in this batch for padding + max_seq_length = max(len(t["input_ids"]) for t in tokenized_trajectories) - # Second pass: pad all trajectories to max_seq_length - trajectory_tensors = [] - for tokenized in tokenized_trajectories: - input_ids = tokenized["input_ids"] - attention_mask = tokenized["attention_mask"] - labels = tokenized["labels"] + # Second pass: pad all trajectories to max_seq_length + trajectory_tensors = [] + for tokenized in tokenized_trajectories: + input_ids = tokenized["input_ids"] + attention_mask = tokenized["attention_mask"] + labels = tokenized["labels"] - # Pad to max_seq_length - padding_length = max_seq_length - len(input_ids) - if padding_length > 0: - input_ids = input_ids + [pad_token_id] * padding_length - attention_mask = attention_mask + [0] * padding_length - labels = labels + [-100] * padding_length + # Pad to max_seq_length + padding_length = max_seq_length - len(input_ids) + if padding_length > 0: + input_ids = input_ids + [pad_token_id] * padding_length + attention_mask = attention_mask + [0] * padding_length + labels = labels + [-100] * padding_length - trajectory_tensor = { - "input_ids": torch.tensor([input_ids], dtype=torch.long), - "attention_mask": torch.tensor([attention_mask], dtype=torch.long), - "labels": torch.tensor([labels], dtype=torch.long), - } + trajectory_tensor = { + "input_ids": torch.tensor([input_ids], dtype=torch.long), + "attention_mask": torch.tensor([attention_mask], dtype=torch.long), + "labels": torch.tensor([labels], dtype=torch.long), + } - trajectory_tensors.append(trajectory_tensor) + trajectory_tensors.append(trajectory_tensor) - # Calculate total trainable tokens (labels != -100) - num_trainable_tokens = sum( - (tensor_dict["labels"] != -100).sum().item() - for tensor_dict in trajectory_tensors - ) + # Calculate total trainable tokens (labels != -100) + num_trainable_tokens = sum( + (tensor_dict["labels"] != -100).sum().item() + for tensor_dict in trajectory_tensors + ) - yield SFTBatch( - trajectory_tensors=trajectory_tensors, - learning_rate=lr, - num_trajectories=len(trajectory_tensors), - num_trainable_tokens=num_trainable_tokens, - ) + return SFTBatch( + trajectory_tensors=trajectory_tensors, + learning_rate=learning_rate, + num_trajectories=len(trajectory_tensors), + num_trainable_tokens=num_trainable_tokens, + ) diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index cabd064c6..4da5fad50 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal +from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal import warnings from openai._types import NOT_GIVEN @@ -9,8 +9,8 @@ from .. import dev from ..backend import AnyTrainableModel, Backend -from ..trajectories import TrajectoryGroup -from ..types import ServerlessTrainResult, TrainConfig +from ..trajectories import Trajectory, TrajectoryGroup +from ..types import ServerlessTrainResult, TrainConfig, TrainSFTConfig from ..utils.record_provenance import record_provenance if TYPE_CHECKING: @@ -67,6 +67,7 @@ async def register( ) model.id = client_model.id model.entity = client_model.entity + model.run_id = client_model.run_id async def delete( self, @@ -321,6 +322,175 @@ async def _train_model( raise RuntimeError(f"Training job failed: {error_message}") after = event.id + async def _train_sft( + self, + model: AnyTrainableModel, + trajectories: Iterable[Trajectory], + config: TrainSFTConfig, + dev_config: dev.TrainSFTConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + """Train the model using supervised fine-tuning. + + For ServerlessBackend, this serializes trajectories to a JSONL file, + uploads it to W&B artifacts, and calls the SFT training API. + + Args: + model: The trainable model to fine-tune. + trajectories: Iterable of Trajectory objects. + config: SFT configuration with batch_size and learning rates. + dev_config: Developer configuration. + verbose: Whether to print detailed logs. + + Yields: + Dictionary containing training metrics for each batch. + """ + import json + import tempfile + import uuid + + import wandb + + assert model.id is not None, "Model ID is required" + + # Get the user's default entity from W&B if not set + if model.entity is None: + api = wandb.Api(api_key=self._client.api_key) + model.entity = api.default_entity + + # Generate unique artifact name to avoid race conditions in distributed systems + artifact_id = uuid.uuid4().hex[:12] + artifact_name = f"{model.name}-sft-data-{artifact_id}" + + if verbose: + print("Serializing trajectories to file (streaming)...") + + # Serialize trajectories to a temporary JSONL file (streaming - no memory load) + num_trajectories = 0 + with tempfile.NamedTemporaryFile( + mode="w", suffix=".jsonl", delete=False + ) as tmp_file: + for trajectory in trajectories: + # Convert trajectory to the expected JSONL format + line: dict[str, Any] = { + "messages": trajectory.messages(), + } + if trajectory.tools: + line["tools"] = trajectory.tools + tmp_file.write(json.dumps(line) + "\n") + num_trajectories += 1 + tmp_file_path = tmp_file.name + + if num_trajectories == 0: + if verbose: + print("No trajectories to train on") + import os + + os.unlink(tmp_file_path) + return + + if verbose: + print(f"Serialized {num_trajectories} trajectories") + + try: + if verbose: + print("Uploading training data to W&B artifacts...") + + # Upload the file to W&B as a dataset artifact + # Use the model's canonical run_id from database, or fall back to model name + run = wandb.init( + name=model.name, + id=model.run_id + or model.name, # Use stored run_id to match the canonical wandb run + entity=model.entity, + project=model.project, + resume="allow", # Resume if this run already exists + settings=wandb.Settings(api_key=self._client.api_key), + ) + try: + artifact = wandb.Artifact( + artifact_name, + type="dataset", + metadata={ + "format": "jsonl", + "num_trajectories": num_trajectories, + }, + ) + artifact.add_file(tmp_file_path, name="train.jsonl") + artifact = run.log_artifact(artifact) + try: + artifact = artifact.wait() + except ValueError as e: + if "Unable to fetch artifact with id" in str(e): + if verbose: + print(f"Warning: {e}") + else: + raise e + finally: + # Finish the run so the workflow can resume it later + # The workflow uses wandb_run with resume="must" to continue this run + run.finish() + finally: + # Clean up temporary file + import os + + os.unlink(tmp_file_path) + + # Construct the artifact URL with unique name (v0 is the first version) + training_data_url = ( + f"wandb-artifact:///{model.entity}/{model.project}/{artifact_name}:v0" + ) + + if verbose: + print(f"Training data uploaded. Artifact URL: {training_data_url}") + print("Starting SFT training job...") + + # Create SFT training job + from .client import SFTTrainingConfig + + sft_config: SFTTrainingConfig = {} + if config.batch_size != "auto": + sft_config["batch_size"] = config.batch_size + sft_config["learning_rate"] = config.learning_rate + + sft_training_job = await self._client.sft_training_jobs.create( + model_id=model.id, + training_data_url=training_data_url, + config=sft_config, + ) + + # Poll for events + after: str | None = None + num_batches: int | None = None + pbar: tqdm.tqdm | None = None + while True: + await asyncio.sleep(1) + async for event in self._client.sft_training_jobs.events.list( + training_job_id=sft_training_job.id, after=after or NOT_GIVEN + ): + if event.type == "gradient_step": + assert pbar is not None and num_batches is not None + pbar.update(1) + pbar.set_postfix(event.data) + yield {**event.data, "num_gradient_steps": num_batches} + elif event.type == "training_started": + num_batches = event.data.get("num_sequences", 0) + if pbar is None: + pbar = tqdm.tqdm(total=num_batches, desc="train sft") + continue + elif event.type == "training_ended": + if pbar is not None: + pbar.close() + return + elif event.type == "training_failed": + if pbar is not None: + pbar.close() + error_message = event.data.get( + "error_message", "SFT training failed with an unknown error" + ) + raise RuntimeError(f"SFT training job failed: {error_message}") + after = event.id + # ------------------------------------------------------------------ # Experimental support for S3 and checkpoints # ------------------------------------------------------------------ diff --git a/src/art/serverless/client.py b/src/art/serverless/client.py index 076418f39..670b76335 100644 --- a/src/art/serverless/client.py +++ b/src/art/serverless/client.py @@ -30,6 +30,7 @@ class Model(BaseModel): project: str name: str base_model: str + run_id: str | None class Checkpoint(BaseModel): @@ -65,10 +66,19 @@ class ExperimentalTrainingConfig(TypedDict, total=False): scale_rewards: bool | None +class SFTTrainingConfig(TypedDict, total=False): + batch_size: int | None + learning_rate: float | list[float] | None + + class TrainingJob(BaseModel): id: str +class SFTTrainingJob(BaseModel): + id: str + + class TrainingJobEventListParams(TypedDict, total=False): after: str limit: int @@ -220,6 +230,29 @@ def list( ) +class SFTTrainingJobs(AsyncAPIResource): + async def create( + self, + *, + model_id: str, + training_data_url: str, + config: SFTTrainingConfig | None = None, + ) -> SFTTrainingJob: + return await self._post( + "/preview/sft-training-jobs", + cast_to=SFTTrainingJob, + body={ + "model_id": model_id, + "training_data_url": training_data_url, + "config": config, + }, + ) + + @cached_property + def events(self) -> "TrainingJobEvents": + return TrainingJobEvents(cast(AsyncOpenAI, self._client)) + + class Client(AsyncAPIClient): api_key: str @@ -264,6 +297,10 @@ def models(self) -> Models: def training_jobs(self) -> TrainingJobs: return TrainingJobs(cast(AsyncOpenAI, self)) + @cached_property + def sft_training_jobs(self) -> SFTTrainingJobs: + return SFTTrainingJobs(cast(AsyncOpenAI, self)) + ############################ # AsyncOpenAI overrides # ############################ diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index c2d9515f4..b04aec363 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -6,7 +6,7 @@ from pathlib import Path import shutil import time -from typing import AsyncIterator, Generator +from typing import Any, AsyncIterator, Generator import tinker from tinker.lib.public_interfaces.rest_client import RestClient as TinkerRestClient @@ -22,6 +22,7 @@ DiskPackedTensors, packed_tensors_from_dir, ) +from ..preprocessing.tokenize import SFTBatch from .server import OpenAICompatibleTinkerServer @@ -148,6 +149,15 @@ def custom_loss_fn( state.renderer, ) + # SFT not supported for TinkerService + async def train_sft( + self, + batches: list[Any], + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + raise NotImplementedError("SFT training is not supported for TinkerService") + yield {} # Make this a generator + async def delete_checkpoints(self, steps_to_keep: list[int]) -> None: state = await self._state_task # Find steps to delete diff --git a/src/art/trajectories.py b/src/art/trajectories.py index 5a907950d..a9a94e680 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -40,7 +40,7 @@ class Trajectory(pydantic.BaseModel): messages_and_choices: MessagesAndChoices tools: Tools | None = None additional_histories: list[History] = [] - reward: float + reward: float = 0.0 initial_policy_version: int | None = None final_policy_version: int | None = None metrics: dict[str, float | int | bool] = {} diff --git a/src/art/types.py b/src/art/types.py index df81d6842..b37dd1fec 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -19,6 +19,11 @@ class TrainConfig(pydantic.BaseModel): beta: float = 0.0 +class TrainSFTConfig(pydantic.BaseModel): + learning_rate: float | list[float] = 5e-5 # Single value or per-batch list + batch_size: int | Literal["auto"] = "auto" + + Verbosity = Literal[0, 1, 2] diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index d42941357..2417cff96 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from functools import cached_property import os -from typing import TYPE_CHECKING, Any, AsyncIterator, Protocol, cast +from typing import TYPE_CHECKING, Any, AsyncIterator, Literal, Protocol, cast from datasets import Dataset import peft @@ -24,6 +24,7 @@ PackedTensors, packed_tensors_from_dir, ) +from ..preprocessing.tokenize import SFTBatch from ..utils.get_model_step import get_step_from_dir from ..utils.output_dirs import get_step_checkpoint_dir from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers @@ -261,6 +262,7 @@ class UnslothService: config: dev.InternalModelConfig output_dir: str _is_sleeping: bool = False + _last_training_mode: Literal["sft", "rl"] | None = None _latest_step: int = 0 _lora_id_counter: int = 1 # Start from 1 since 0 is reserved @@ -306,7 +308,6 @@ async def vllm_engine_is_sleeping(self) -> bool: async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: """Register a LoRA adapter for a specific checkpoint step. - This is called when training is skipped but the checkpoint is renamed. """ llm = await self.llm @@ -325,6 +326,26 @@ async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: self._latest_step = step await llm.resume_generation() + def _reset_optimizer_if_mode_changed( + self, + mode: Literal["sft", "rl"], + ) -> None: + """Reset optimizer state if training mode changed. + + Uses a single shared optimizer (trainer.optimizer) for both SFT and RL. + Resets optimizer state (momentum, variance) only when switching between + training modes to avoid stale state from a different loss landscape. + """ + mode_changed = ( + self._last_training_mode is not None and self._last_training_mode != mode + ) + + if mode_changed: + # Clear all optimizer state (exp_avg, exp_avg_sq, step for each param) + self._state.trainer.optimizer.state.clear() + + self._last_training_mode = mode + async def train( self, disk_packed_tensors: DiskPackedTensors, @@ -356,6 +377,14 @@ async def train( # Reload training model to GPU (after vLLM is asleep) self._state.reload_to_gpu() + # Reset optimizer state if switching from SFT to RL + self._reset_optimizer_if_mode_changed("rl") + + # Set RL-specific hyperparameters + rl_weight_decay = 0.1 + for param_group in self._state.trainer.optimizer.param_groups: + param_group["weight_decay"] = rl_weight_decay + # Load packed tensors packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) @@ -434,6 +463,178 @@ async def train( if verbose: print("UnslothService.train complete") + # ========================================================================= + # SFT training + # ========================================================================= + + async def train_sft( + self, + batches: list[SFTBatch], + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + """Train using SFT on pre-computed batches. + + Args: + batches: List of SFTBatch objects to train on. + verbose: Whether to print detailed logs. + + Yields: + Dictionary containing training metrics for each batch. + """ + import time + + llm = await self.llm + + # === Setup === + # Pause generation to prevent new requests during training + await llm.pause_generation() + + # Determine sleep level based on outstanding requests + has_unfinished = llm.output_processor.has_unfinished_requests() + if has_unfinished: + sleep_level = 1 + else: + await llm.reset_prefix_cache() + sleep_level = 2 + + # Put workers to sleep + await run_on_workers(llm, do_sleep, level=sleep_level) + self._is_sleeping = True + gc_and_empty_cuda_cache() + + # Reload training model to GPU (after vLLM is asleep) + self._state.reload_to_gpu() + + # Get model and optimizer + peft_model = self._state.peft_model + self._reset_optimizer_if_mode_changed("sft") + optimizer = self._state.trainer.optimizer + + # Set SFT-specific hyperparameters + sft_weight_decay = 0.01 + for param_group in optimizer.param_groups: + param_group["weight_decay"] = sft_weight_decay + + # Reset environment variable that may be set by RL training + os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" + + peft_model.train() + device = next(peft_model.parameters()).device + max_grad_norm = 1.0 + + if verbose: + print("SFT training started") + + # === Process batches === + batch_idx = 0 + for batch in batches: + batch_start_time = time.perf_counter() + batch_loss = 0.0 + + # Update learning rate for this batch + for param_group in optimizer.param_groups: + param_group["lr"] = batch.learning_rate + + # Total trainable tokens for loss normalization + num_items_in_batch = torch.tensor( + batch.num_trainable_tokens, dtype=torch.long, device=device + ) + + # Process each trajectory in the batch (gradient accumulation) + for trajectory_tensor in batch.trajectory_tensors: + # Move tensors to device + input_ids = trajectory_tensor["input_ids"].to(device) + attention_mask = trajectory_tensor["attention_mask"].to(device) + labels = trajectory_tensor["labels"].to(device) + + # Forward pass with num_items_in_batch for proper loss normalization + outputs = peft_model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + num_items_in_batch=num_items_in_batch, + ) + + loss = outputs.loss + + # Backward pass - accumulate gradients + loss.backward() + + # Track metrics + batch_loss += loss.item() + + # Gradient clipping + grad_norm = torch.nn.utils.clip_grad_norm_( + peft_model.parameters(), max_grad_norm + ).item() + + # Optimizer step at the end of each batch + optimizer.step() + optimizer.zero_grad() + + # Compute timing metrics + batch_time = time.perf_counter() - batch_start_time + tokens_per_second = ( + batch.num_trainable_tokens / batch_time if batch_time > 0 else 0.0 + ) + + if verbose: + print( + f"Batch {batch_idx}: loss={batch_loss:.4f}, lr={batch.learning_rate:.2e}, " + f"grad_norm={grad_norm:.4f}, tok/s={tokens_per_second:.1f}" + ) + + batch_idx += 1 + + yield { + "loss": batch_loss, + "learning_rate": batch.learning_rate, + "grad_norm": grad_norm, + "num_trajectories": float(batch.num_trajectories), + "num_trainable_tokens": float(batch.num_trainable_tokens), + "tokens_per_second": tokens_per_second, + } + + # === Cleanup === + # Save checkpoint after training + checkpoint_dir = save_checkpoint( + trainer=self._state.trainer, + output_dir=self.output_dir, + verbose=verbose, + ) + + # Offload training model to CPU before waking vLLM + self._state.offload_to_cpu() + + # Free memory before waking up vLLM + gc_and_empty_cuda_cache() + await asyncio.sleep(0.5) + + # Wake up workers + await run_on_workers(llm, do_wake_up) + self._is_sleeping = False + + # Add the new LoRA adapter + new_step = int(os.path.basename(checkpoint_dir)) + added = await llm.add_lora( + LoRARequest( + lora_name=f"{self.model_name}@{new_step}", + lora_int_id=self._next_lora_id(), + lora_path=checkpoint_dir, + ) + ) + if not added: + raise RuntimeError( + f"Failed to add LoRA adapter for step {new_step} at {checkpoint_dir}" + ) + self._latest_step = new_step + + # Resume generation after LoRA swap is complete + await llm.resume_generation() + + if verbose: + print("SFT training finished") + @cached_property def _state(self) -> UnslothState: import unsloth @@ -476,6 +677,10 @@ def _state(self) -> UnslothState: processing_class=tokenizer, ) + # Initialize optimizer eagerly using trainer's configured settings. + if trainer.optimizer is None: + trainer.create_optimizer() + # Initialize queues inputs_queue: asyncio.Queue[TrainInputs] = asyncio.Queue() results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue() diff --git a/src/art/utils/model_config.py b/src/art/utils/model_config.py new file mode 100644 index 000000000..a3d364aa9 --- /dev/null +++ b/src/art/utils/model_config.py @@ -0,0 +1,100 @@ +"""Model-specific configuration for chat templates and training defaults.""" + +from dataclasses import dataclass + + +@dataclass +class ModelConfig: + """Configuration for a specific model's chat template.""" + + instruction_part: str + response_part: str + + +# Explicit model configurations for models that can't be auto-detected. +# Models not listed here will fall back to auto-detection from the tokenizer's chat_template. +MODEL_CONFIGS: dict[str, ModelConfig] = { + # Qwen3 with thinking disabled - always includes empty tags + "OpenPipe/Qwen3-14B-Instruct": ModelConfig( + instruction_part="<|im_start|>user\n", + response_part="<|im_start|>assistant\n\n\n\n\n", + ), +} + + +def detect_chat_template_parts( + tokenizer: object, +) -> tuple[str, str]: + """Detect instruction and response parts from a tokenizer's chat template. + + Args: + tokenizer: A tokenizer with a chat_template attribute + + Returns: + Tuple of (instruction_part, response_part) + + Raises: + ValueError: If the tokenizer has no chat_template or the format is unrecognized + """ + template = getattr(tokenizer, "chat_template", None) + if not template or not isinstance(template, str): + raise ValueError( + "Cannot detect chat template parts: tokenizer has no chat_template attribute. " + "Please specify instruction_part and response_part manually." + ) + + # ChatML format (Qwen, etc.) + if "<|im_start|>" in template: + return "<|im_start|>user\n", "<|im_start|>assistant\n" + + # Llama 3 format + if "<|start_header_id|>" in template: + return ( + "<|start_header_id|>user<|end_header_id|>\n\n", + "<|start_header_id|>assistant<|end_header_id|>\n\n", + ) + + # Gemma format + if "" in template: + return "user\n", "model\n" + + # Mistral format + if "[INST]" in template: + return "[INST]", "[/INST]" + + raise ValueError( + f"Unrecognized chat template format. " + f"Please specify instruction_part and response_part manually. " + f"Template starts with: {template[:100]!r}..." + ) + + +def get_instruction_response_parts( + model_id: str, + tokenizer: object, +) -> tuple[str, str]: + """Get instruction and response parts for a model. + + Checks for explicit model configuration first, then falls back to + auto-detection from the tokenizer's chat template. + + Args: + model_id: The model identifier + tokenizer: Tokenizer with chat_template attribute + + Returns: + Tuple of (instruction_part, response_part) + + Raises: + ValueError: If chat template cannot be detected and model has no explicit config + """ + # Check for explicit model configuration first + if model_id in MODEL_CONFIGS: + config = MODEL_CONFIGS[model_id] + return config.instruction_part, config.response_part + + # Fall back to auto-detection + try: + return detect_chat_template_parts(tokenizer) + except ValueError as e: + raise ValueError(f"Failed to detect chat template for {model_id}: {e}") from e diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py new file mode 100644 index 000000000..298301bc1 --- /dev/null +++ b/src/art/utils/sft.py @@ -0,0 +1,442 @@ +"""Utilities for supervised fine-tuning (SFT).""" + +import itertools +import json +import math +import random +from typing import TYPE_CHECKING, Generator, List, Literal, NamedTuple + +if TYPE_CHECKING: + from art.dev import TrainSFTConfig as DevTrainSFTConfig + from art.model import TrainableModel + from art.trajectories import Trajectory + from art.types import TrainSFTConfig + + +class SFTChunk(NamedTuple): + trajectories: "list[Trajectory]" + config: "TrainSFTConfig" + step: int + epoch: int + epoch_step: int + + +def _parse_jsonl_line(line: str) -> "Trajectory": + """Parse a JSONL line into a Trajectory object. + + Args: + line: A JSON string containing trajectory data with 'messages' and optional 'tools'. + + Returns: + A Trajectory object with the parsed data. + """ + from art.trajectories import Trajectory + + data = json.loads(line) + return Trajectory( + messages_and_choices=data.get("messages", []), + tools=data.get("tools"), + ) + + +def get_file_row_count(file_path: str) -> int: + """ + Count the number of non-empty rows in a JSONL file. + + Args: + file_path: Path to JSONL file + + Returns: + Number of non-empty lines in the file + + Raises: + ValueError: If file_path does not end with .jsonl + + Example: + count = get_file_row_count("data.jsonl") + print(f"Dataset has {count} items") + """ + if not file_path.endswith(".jsonl"): + raise ValueError(f"Only JSONL files are supported. Got: {file_path}") + + count = 0 + with open(file_path, "r") as f: + for line in f: + if line.strip(): + count += 1 + return count + + +def create_lr_schedule( + total_steps: int, + peak_lr: float, + method: Literal["cosine", "linear", "constant"] = "linear", + warmup_steps: int = 0, + min_lr: float = 0.0, +) -> List[float]: + """ + Create learning rate schedule for training with optional warmup. + + Args: + total_steps: Total number of training steps + peak_lr: Peak learning rate + method: Learning rate schedule method. Options: + - "cosine": Cosine annealing from peak_lr to min_lr + - "linear": Linear decay from peak_lr to min_lr + - "constant": Constant learning rate (peak_lr for all steps) + warmup_steps: Number of warmup steps (linear warmup from 0 to peak_lr) + min_lr: Minimum learning rate (floor for decay schedules) + + Returns: + List of learning rates for each step + + Example: + # Cosine schedule with warmup + lrs = create_lr_schedule(100, 1e-4, method="cosine", warmup_steps=10) + + # Use with training loop + for step, chunk in enumerate(chunk_trajectories(...)): + train_sft(chunk, learning_rate=lrs[step]) + """ + if total_steps <= 0: + return [] + + learning_rates = [] + decay_steps = total_steps - warmup_steps + + for step in range(total_steps): + if step < warmup_steps: + # Warmup: linear ramp from min_lr to peak_lr + # Use (step + 1) so first step has lr > 0 + lr = min_lr + (peak_lr - min_lr) * ((step + 1) / warmup_steps) + else: + # Decay phase: progress goes from 0 to 1 + progress = ( + (step - warmup_steps) / (decay_steps - 1) if decay_steps > 1 else 0 + ) + if method == "cosine": + lr = min_lr + (peak_lr - min_lr) * 0.5 * ( + 1 + math.cos(math.pi * progress) + ) + elif method == "linear": + lr = peak_lr - (peak_lr - min_lr) * progress + elif method == "constant": + lr = peak_lr + else: + raise ValueError( + f"Unknown method: {method}. Choose from: cosine, linear, constant" + ) + + learning_rates.append(lr) + + return learning_rates + + +def create_sft_dataset_iterator( + trajectories: "list[Trajectory]", + chunk_size: int = 10, + epochs: int = 1, + batch_size: int = 2, + peak_lr: float = 2e-4, + schedule_type: Literal["cosine", "linear", "constant"] = "linear", + warmup_ratio: float = 0.1, + shuffle: bool = True, + seed: int = 42, + initial_step: int = 0, + show_progress: bool = True, +) -> "Generator[SFTChunk, None, None]": + """ + Prepare trajectories in chunks for multiple model.train_sft() calls. + + Yields SFTChunk objects so that each call to model.train_sft() produces + its own training metrics. The learning rate schedule is computed over the + entire dataset, then sliced so that scheduling (warmup, decay) is correct + across all chunks. + + Args: + trajectories: List of trajectories to train on. + chunk_size: Number of batches to process per train_sft call. Default: 10. + This is an internal optimization parameter and does not affect training. + epochs: Number of times to repeat the dataset. Default: 1 + batch_size: Number of trajectories per batch. Default: 2 + peak_lr: Peak learning rate. Default: 2e-4 + schedule_type: LR schedule ("cosine", "linear", "constant"). Default: "linear" + warmup_ratio: Fraction of total steps used for warmup. Default: 0.1 + shuffle: Whether to shuffle trajectories each epoch. Default: True + seed: Random seed. Each epoch uses seed + epoch_number. Default: 42 + initial_step: Global batch step to resume from. Default: 0 + show_progress: Whether to display a tqdm progress bar. Default: True + + Yields: + SFTChunk(trajectories, config, step, epoch, epoch_step). + + Example: + for chunk in create_sft_dataset_iterator( + trajectories=my_trajectories, + chunk_size=10, + epochs=3, + batch_size=2, + peak_lr=2e-4, + ): + await model.train_sft(chunk.trajectories, chunk.config) + """ + from tqdm.auto import tqdm + + from art.types import TrainSFTConfig as SFTConfig + + if chunk_size < 1: + raise ValueError(f"chunk_size must be >= 1, got {chunk_size}") + + dataset_size = len(trajectories) + if dataset_size == 0: + return + + batches_per_epoch = math.ceil(dataset_size / batch_size) + total_batches = batches_per_epoch * epochs + warmup_steps = int(total_batches * warmup_ratio) + + # Compute full LR schedule across all data + full_schedule = create_lr_schedule( + total_steps=total_batches, + peak_lr=peak_lr, + method=schedule_type, + warmup_steps=warmup_steps, + ) + + # chunk_size is in batches; compute trajectory count per chunk + items_per_chunk = batch_size * chunk_size + chunks_per_epoch = math.ceil(dataset_size / items_per_chunk) + + # Convert initial_step (batch-based) to initial_chunk for skipping + initial_chunk = initial_step // chunk_size + + pbar = ( + tqdm( + initial=initial_step, total=total_batches, desc="SFT Training", unit="step" + ) + if show_progress + else None + ) + + for epoch in range(epochs): + epoch_trajs = list(trajectories) + if shuffle: + random.Random(seed + epoch).shuffle(epoch_trajs) + + for chunk_idx in range(chunks_per_epoch): + global_chunk_idx = epoch * chunks_per_epoch + chunk_idx + + # Skip chunks before initial_step + if global_chunk_idx < initial_chunk: + continue + + chunk_start = chunk_idx * items_per_chunk + chunk_end = min(chunk_start + items_per_chunk, dataset_size) + chunk_trajs = epoch_trajs[chunk_start:chunk_end] + + num_batches_in_chunk = math.ceil(len(chunk_trajs) / batch_size) + global_batch_step = epoch * batches_per_epoch + (chunk_start // batch_size) + epoch_batch_step = chunk_start // batch_size + + chunk_lrs = full_schedule[ + global_batch_step : global_batch_step + num_batches_in_chunk + ] + + config = SFTConfig( + learning_rate=chunk_lrs, + batch_size=batch_size, + ) + + yield SFTChunk( + trajectories=chunk_trajs, + config=config, + step=global_batch_step, + epoch=epoch, + epoch_step=epoch_batch_step, + ) + + if pbar: + pbar.update(num_batches_in_chunk) + + if pbar: + pbar.close() + + +def iterate_file( + file_path: str, + epochs: int = 1, + shuffle_buffer_size: int = 10000, + seed: int = 42, + initial_skip: int = 0, +) -> Generator["Trajectory", None, None]: + """ + Stream trajectories from a JSONL file for one or more epochs. + + Uses buffer-based shuffling to randomize order without loading all data + into memory. Each epoch uses a different seed for varied shuffling. + + Args: + file_path: Path to JSONL file (one JSON object per line) + epochs: Number of times to iterate over the file. Default: 1 + shuffle_buffer_size: Size of shuffle buffer. Default: 10000. + Larger values give better shuffling but use more memory. + seed: Base random seed. Each epoch uses seed + epoch_number. Default: 42 + initial_skip: Number of trajectories to skip (for resuming). Default: 0 + + Yields: + Trajectory objects + + Raises: + ValueError: If file_path does not end with .jsonl + + Example: + for trajectory in iterate_file("data.jsonl", epochs=3): + process(trajectory) + """ + if not file_path.endswith(".jsonl"): + raise ValueError(f"Only JSONL files are supported. Got: {file_path}") + + skipped = 0 + + for epoch in range(epochs): + rng = random.Random(seed + epoch) + shuffle_buffer: List["Trajectory"] = [] + + with open(file_path, "r") as f: + for line in f: + if not line.strip(): + continue + + traj = _parse_jsonl_line(line) + shuffle_buffer.append(traj) + + # Once buffer is full, start yielding randomly + if len(shuffle_buffer) >= shuffle_buffer_size: + idx = rng.randint(0, len(shuffle_buffer) - 1) + item = shuffle_buffer.pop(idx) + + if skipped < initial_skip: + skipped += 1 + else: + yield item + + # Flush remaining items in shuffle buffer + rng.shuffle(shuffle_buffer) + for traj in shuffle_buffer: + if skipped < initial_skip: + skipped += 1 + else: + yield traj + + +async def train_sft_from_file( + model: "TrainableModel", + file_path: str, + epochs: int = 1, + batch_size: int = 2, + peak_lr: float = 2e-4, + schedule_type: Literal["cosine", "linear", "constant"] = "linear", + warmup_ratio: float = 0.1, + initial_step: int = 0, + final_step: int | None = None, + _config: "DevTrainSFTConfig | None" = None, + verbose: bool = False, + shuffle_buffer_size: int = 10000, +) -> None: + """ + Train a model using supervised fine-tuning from a JSONL file. + + Streams data without loading all into memory. Suitable for large files (10GB+). + + Args: + model: The TrainableModel to fine-tune. Must be registered with a backend. + file_path: Path to JSONL file containing training data. Each line should have: + - messages: List of chat messages + - tools: Optional list of tools + epochs: Number of times to iterate over the dataset. Default: 1 + batch_size: Number of trajectories per batch. Default: 2 + peak_lr: Peak learning rate. Default: 2e-4 + schedule_type: Learning rate schedule ("cosine", "linear", "constant"). Default: "linear" + warmup_ratio: Ratio of total steps for warmup (0.0 to 1.0). Default: 0.1 + initial_step: Starting step for resuming training. Default: 0 + final_step: Ending step (exclusive). If None, trains to end of dataset. + Useful for breaking training into segments with benchmarks in between. + _config: Experimental configuration. Use at your own risk. + verbose: Whether to print verbose output. Default: False + shuffle_buffer_size: Size of shuffle buffer. Default: 10000. + Larger values give better shuffling but use more memory. + + Example: + await train_sft_from_file( + model=model, + file_path="data/train.jsonl", + epochs=3, + batch_size=4, + peak_lr=2e-4, + ) + """ + from art.types import TrainSFTConfig + + row_count = get_file_row_count(file_path) + + if verbose: + print(f"File has {row_count} rows") + + if row_count == 0: + if verbose: + print("No trajectories to train on") + return + + # Calculate total trajectories and batches + total_trajectories = row_count * epochs + skip_trajectories = initial_step * batch_size + + if skip_trajectories >= total_trajectories: + if verbose: + print(f"initial_step {initial_step} skips all trajectories") + return + + total_batches = math.ceil(total_trajectories / batch_size) + warmup_steps = int(total_batches * warmup_ratio) + + if final_step is not None and final_step > total_batches: + final_step = total_batches + + # Create learning rate schedule + full_schedule = create_lr_schedule( + total_steps=total_batches, + peak_lr=peak_lr, + method=schedule_type, + warmup_steps=warmup_steps, + ) + learning_rates = full_schedule[initial_step:final_step] + + if verbose: + num_training_trajectories = len(learning_rates) * batch_size + print(f"Training {num_training_trajectories} trajectories") + print(f"Batches: {len(learning_rates)}, batch_size: {batch_size}") + print(f"Schedule: {schedule_type}, peak_lr: {peak_lr}") + + # Stream trajectories from file, capped to the number we need + max_trajectories = len(learning_rates) * batch_size + trajectories = itertools.islice( + iterate_file( + file_path=file_path, + epochs=epochs, + shuffle_buffer_size=shuffle_buffer_size, + initial_skip=skip_trajectories, + ), + max_trajectories, + ) + + config = TrainSFTConfig( + learning_rate=learning_rates, + batch_size=batch_size, + ) + + await model.train_sft( + trajectories, + config, + _config=_config, + verbose=verbose, + ) diff --git a/tests/unit/test_frontend_logging.py b/tests/unit/test_frontend_logging.py index eb0a1c595..202785892 100644 --- a/tests/unit/test_frontend_logging.py +++ b/tests/unit/test_frontend_logging.py @@ -573,3 +573,128 @@ def test_report_metrics_custom(self): """Verify report_metrics can be customized.""" model = Model(name="test", project="test", report_metrics=["wandb", "custom"]) assert model.report_metrics == ["wandb", "custom"] + + +class TestTrainSFTMetricsAggregation: + """Test that train_sft aggregates metrics and logs once (same as RL).""" + + @pytest.mark.asyncio + async def test_train_sft_aggregates_metrics(self, tmp_path: Path): + """Verify train_sft aggregates metrics from multiple batches into one log entry.""" + model = TrainableModel( + name="test-sft", + project="test-project", + base_model="Qwen/Qwen2.5-0.5B-Instruct", + base_path=str(tmp_path), + ) + + # Mock the backend to yield multiple batch metrics + mock_backend = MagicMock() + + async def mock_train_sft(*args, **kwargs): + # Simulate 3 batches with different metrics + yield {"loss": 1.0, "learning_rate": 1e-4, "grad_norm": 0.5} + yield {"loss": 0.8, "learning_rate": 1e-4, "grad_norm": 0.4} + yield {"loss": 0.6, "learning_rate": 1e-4, "grad_norm": 0.3} + + mock_backend._train_sft = mock_train_sft + mock_backend._get_step = AsyncMock(return_value=1) # Step after training + model._backend = mock_backend + + # Create dummy trajectories + trajectories = [ + Trajectory( + reward=0.0, + messages_and_choices=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ], + ) + for _ in range(3) + ] + + # Run train_sft + await model.train_sft(trajectories) + + # Verify history.jsonl has exactly ONE entry (not 3) + history_path = tmp_path / "test-project/models/test-sft/history.jsonl" + assert history_path.exists(), "history.jsonl should be created" + + with open(history_path) as f: + lines = f.readlines() + + assert len(lines) == 1, f"Expected 1 log entry, got {len(lines)}" + + # Verify metrics are aggregated (averaged) + entry = json.loads(lines[0]) + assert entry["step"] == 1 + assert entry["train/loss"] == pytest.approx(0.8) # (1.0 + 0.8 + 0.6) / 3 + assert entry["train/grad_norm"] == pytest.approx(0.4) # (0.5 + 0.4 + 0.3) / 3 + + @pytest.mark.asyncio + async def test_train_sft_single_step_increment(self, tmp_path: Path): + """Verify train_sft results in single step increment regardless of batch count.""" + model = TrainableModel( + name="test-sft-step", + project="test-project", + base_model="gpt-4", + base_path=str(tmp_path), + ) + + mock_backend = MagicMock() + + async def mock_train_sft(*args, **kwargs): + # Simulate 5 batches + for i in range(5): + yield {"loss": 1.0 - i * 0.1} + + mock_backend._train_sft = mock_train_sft + mock_backend._get_step = AsyncMock(return_value=1) # Step is 1 after training + model._backend = mock_backend + + trajectories = [ + Trajectory( + reward=0.0, + messages_and_choices=[{"role": "user", "content": f"msg{i}"}], + ) + for i in range(10) + ] + + await model.train_sft(trajectories) + + # Verify only one log entry at step 1 + history_path = tmp_path / "test-project/models/test-sft-step/history.jsonl" + df = pl.read_ndjson(str(history_path)) + + assert len(df) == 1, "Should have exactly 1 log entry" + assert df["step"][0] == 1, "Step should be 1 (single increment)" + + @pytest.mark.asyncio + async def test_train_sft_no_metrics_when_empty(self, tmp_path: Path): + """Verify train_sft handles empty training gracefully.""" + model = TrainableModel( + name="test-sft-empty", + project="test-project", + base_model="gpt-4", + base_path=str(tmp_path), + ) + + mock_backend = MagicMock() + + async def mock_train_sft(*args, **kwargs): + # No batches yielded (empty training) + return + yield # Make it a generator + + mock_backend._train_sft = mock_train_sft + model._backend = mock_backend + + trajectories = [] + + await model.train_sft(trajectories) + + # Verify no history.jsonl created (no metrics to log) + history_path = tmp_path / "test-project/models/test-sft-empty/history.jsonl" + assert not history_path.exists(), ( + "No history.jsonl should be created for empty training" + ) diff --git a/tests/unit/test_multi_checkpoint_inference.py b/tests/unit/test_multi_checkpoint_inference.py index dadaf09a4..cd91e2f31 100644 --- a/tests/unit/test_multi_checkpoint_inference.py +++ b/tests/unit/test_multi_checkpoint_inference.py @@ -6,14 +6,13 @@ The key features tested are: 1. Model.get_inference_name() with optional step parameter -2. TinkerState.get_sampler_client() for step-based routing -3. ServerlessBackend._model_inference_name() with step suffix -4. UnslothService max_loras configuration +2. ServerlessBackend._model_inference_name() with step suffix +3. UnslothService max_loras configuration """ import asyncio from dataclasses import dataclass -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest @@ -126,87 +125,6 @@ def test_litellm_completion_params_trainable_model_with_step(self): assert params_with_step["model"] == "hosted_vllm/trainable-model@3" -# ============================================================================= -# TinkerState Tests -# ============================================================================= - - -class TestTinkerStateGetSamplerClient: - """Test TinkerState.get_sampler_client() for step-based routing.""" - - @pytest.fixture - def tinker_state_class(self): - """Import TinkerState, skipping if dependencies unavailable.""" - try: - from art.tinker.service import TinkerState - - return TinkerState - except ImportError as e: - pytest.skip(f"Tinker dependencies not available: {e}") - - def test_get_sampler_client_without_step_returns_latest(self, tinker_state_class): - """Without step, should return client for latest_step.""" - TinkerState = tinker_state_class - - # Create mock sampler clients - mock_client_0 = MagicMock() - mock_client_5 = MagicMock() - - state = TinkerState( - service_client=MagicMock(), - rest_client=MagicMock(), - training_client=MagicMock(), - sampler_clients={0: mock_client_0, 5: mock_client_5}, - latest_step=5, - renderer=MagicMock(), - ) - - assert state.get_sampler_client() is mock_client_5 - assert state.get_sampler_client(step=None) is mock_client_5 - - def test_get_sampler_client_with_step_returns_specific_client( - self, tinker_state_class - ): - """With step, should return client for that specific step.""" - TinkerState = tinker_state_class - - mock_client_0 = MagicMock() - mock_client_3 = MagicMock() - mock_client_5 = MagicMock() - - state = TinkerState( - service_client=MagicMock(), - rest_client=MagicMock(), - training_client=MagicMock(), - sampler_clients={0: mock_client_0, 3: mock_client_3, 5: mock_client_5}, - latest_step=5, - renderer=MagicMock(), - ) - - assert state.get_sampler_client(step=0) is mock_client_0 - assert state.get_sampler_client(step=3) is mock_client_3 - assert state.get_sampler_client(step=5) is mock_client_5 - - def test_get_sampler_client_invalid_step_raises_error(self, tinker_state_class): - """Invalid step should raise ValueError with available steps.""" - TinkerState = tinker_state_class - - state = TinkerState( - service_client=MagicMock(), - rest_client=MagicMock(), - training_client=MagicMock(), - sampler_clients={0: MagicMock(), 5: MagicMock()}, - latest_step=5, - renderer=MagicMock(), - ) - - with pytest.raises(ValueError) as exc_info: - state.get_sampler_client(step=3) - - assert "No sampler client for step 3" in str(exc_info.value) - assert "Available steps: [0, 5]" in str(exc_info.value) - - # ============================================================================= # ServerlessBackend Tests # ============================================================================= diff --git a/tests/unit/test_sft.py b/tests/unit/test_sft.py new file mode 100644 index 000000000..3c57b2e07 --- /dev/null +++ b/tests/unit/test_sft.py @@ -0,0 +1,342 @@ +"""Unit tests for SFT utilities.""" + +import json +from pathlib import Path +import tempfile + +import pytest + +from art.utils.sft import ( + create_lr_schedule, + create_sft_dataset_iterator, + iterate_file, +) + + +# Helper to create a temporary JSONL file +def create_temp_jsonl(num_trajectories: int) -> Path: + """Create a temporary JSONL file with dummy trajectories.""" + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) + for i in range(num_trajectories): + data = { + "messages": [ + {"role": "user", "content": f"Message {i}"}, + {"role": "assistant", "content": f"Response {i}"}, + ], + } + temp_file.write(json.dumps(data) + "\n") + temp_file.close() + return Path(temp_file.name) + + +def test_iterate_file(): + """Test iterate_file reads trajectories correctly.""" + jsonl_file = create_temp_jsonl(10) + + try: + trajectories = list(iterate_file(str(jsonl_file), epochs=1)) + + assert len(trajectories) == 10 + + finally: + jsonl_file.unlink() + + +def test_iterate_file_multiple_epochs(): + """Test iterate_file with multiple epochs.""" + jsonl_file = create_temp_jsonl(10) + + try: + trajectories = list(iterate_file(str(jsonl_file), epochs=3)) + + # Should have 30 trajectories (10 * 3 epochs) + assert len(trajectories) == 30 + + finally: + jsonl_file.unlink() + + +def test_iterate_file_with_initial_skip(): + """Test iterate_file with initial_skip for resuming.""" + jsonl_file = create_temp_jsonl(10) + + try: + # Skip first 5 trajectories + trajectories = list(iterate_file(str(jsonl_file), epochs=1, initial_skip=5)) + + assert len(trajectories) == 5 + + finally: + jsonl_file.unlink() + + +def test_iterate_file_deterministic(): + """Test that iterate_file is deterministic with same seed.""" + jsonl_file = create_temp_jsonl(20) + + try: + traj1 = list(iterate_file(str(jsonl_file), epochs=1, seed=42)) + traj2 = list(iterate_file(str(jsonl_file), epochs=1, seed=42)) + + # Should get same order + for t1, t2 in zip(traj1, traj2): + assert t1.messages_and_choices == t2.messages_and_choices + + finally: + jsonl_file.unlink() + + +def test_lr_schedule_warmup_not_zero(): + """Test that warmup doesn't start at lr=0.""" + lrs = create_lr_schedule( + total_steps=10, + peak_lr=1e-4, + method="constant", + warmup_steps=5, + min_lr=0.0, + ) + + # First step should NOT be 0 + assert lrs[0] > 0 + # Should reach peak_lr by end of warmup + assert lrs[4] == pytest.approx(1e-4) + # After warmup, should stay at peak_lr (constant schedule) + assert lrs[5] == pytest.approx(1e-4) + + +def test_lr_schedule_edge_cases(): + """Test LR schedule edge cases.""" + # Empty schedule + lrs = create_lr_schedule(total_steps=0, peak_lr=1e-4) + assert lrs == [] + + # Single step + lrs = create_lr_schedule(total_steps=1, peak_lr=1e-4) + assert len(lrs) == 1 + assert lrs[0] == pytest.approx(1e-4) + + # Warmup steps >= total_steps (edge case) + lrs = create_lr_schedule(total_steps=5, peak_lr=1e-4, warmup_steps=10) + assert len(lrs) == 5 + # Should not crash and should produce valid learning rates + assert all(lr > 0 for lr in lrs) + + +def test_lr_schedule_decay_methods(): + """Test that cosine and linear decay work correctly.""" + peak_lr = 1e-4 + min_lr = 1e-5 + + # Linear decay: should go from peak_lr to min_lr + lrs = create_lr_schedule( + total_steps=5, peak_lr=peak_lr, method="linear", min_lr=min_lr + ) + assert lrs[0] == pytest.approx(peak_lr) # Start at peak + assert lrs[-1] == pytest.approx(min_lr) # End at min + # Should be monotonically decreasing + for i in range(len(lrs) - 1): + assert lrs[i] >= lrs[i + 1] + + # Cosine decay: should go from peak_lr to min_lr + lrs = create_lr_schedule( + total_steps=5, peak_lr=peak_lr, method="cosine", min_lr=min_lr + ) + assert lrs[0] == pytest.approx(peak_lr) # Start at peak + assert lrs[-1] == pytest.approx(min_lr) # End at min + + +def test_lr_schedule_no_warmup(): + """Test schedule with warmup_steps=0.""" + lrs = create_lr_schedule( + total_steps=5, peak_lr=1e-4, method="linear", warmup_steps=0, min_lr=0 + ) + assert len(lrs) == 5 + assert lrs[0] == pytest.approx(1e-4) # Start at peak (no warmup) + assert lrs[-1] == pytest.approx(0) # End at min_lr + + +def _make_trajectories(n: int): + """Create n dummy trajectories.""" + from art.trajectories import Trajectory + + return [ + Trajectory( + messages_and_choices=[ + {"role": "user", "content": f"Message {i}"}, + {"role": "assistant", "content": f"Response {i}"}, + ], + ) + for i in range(n) + ] + + +def test_create_sft_dataset_iterator_lr_schedule_continuity(): + """Test that concatenated chunk LRs match the full schedule from create_lr_schedule.""" + trajs = _make_trajectories(100) + + # Compute the expected full schedule directly + import math + + total_batches = math.ceil(len(trajs) / 2) * 2 # 2 epochs + warmup_steps = int(total_batches * 0.1) + expected_lrs = create_lr_schedule( + total_steps=total_batches, + peak_lr=2e-4, + method="linear", + warmup_steps=warmup_steps, + ) + + chunks = list( + create_sft_dataset_iterator( + trajs, + chunk_size=15, + epochs=2, + batch_size=2, + peak_lr=2e-4, + seed=42, + show_progress=False, + ) + ) + + all_lrs: list[float] = [] + for chunk in chunks: + lr = chunk.config.learning_rate + if isinstance(lr, list): + all_lrs.extend(lr) + else: + all_lrs.append(lr) + + assert expected_lrs == all_lrs + + +def test_create_sft_dataset_iterator_step_tracking(): + """Test that step, epoch, and epoch_step are correct on each chunk.""" + trajs = _make_trajectories(20) + chunks = list( + create_sft_dataset_iterator( + trajs, + chunk_size=5, # 5 batches * 2 batch_size = 10 trajectories per chunk + epochs=2, + batch_size=2, + peak_lr=1e-4, + show_progress=False, + ) + ) + + # 20 trajs, chunk_size=5 batches -> 10 trajs/chunk -> 2 chunks per epoch, 2 epochs -> 4 chunks + assert len(chunks) == 4 + + assert chunks[0].step == 0 + assert chunks[0].epoch == 0 + assert chunks[0].epoch_step == 0 + + assert chunks[1].step == 5 # 10 trajs / batch_size 2 = 5 batches + assert chunks[1].epoch == 0 + assert chunks[1].epoch_step == 5 + + assert chunks[2].step == 10 + assert chunks[2].epoch == 1 + assert chunks[2].epoch_step == 0 + + assert chunks[3].step == 15 + assert chunks[3].epoch == 1 + assert chunks[3].epoch_step == 5 + + +def test_create_sft_dataset_iterator_initial_step(): + """Test that initial_step skips completed chunks.""" + trajs = _make_trajectories(100) + all_chunks = list( + create_sft_dataset_iterator( + trajs, + chunk_size=25, # 25 batches * 2 batch_size = 50 trajectories per chunk + epochs=1, + batch_size=2, + peak_lr=2e-4, + show_progress=False, + ) + ) + + # Resume from step 25 (after first chunk of 25 batches) + resumed_chunks = list( + create_sft_dataset_iterator( + trajs, + chunk_size=25, + epochs=1, + batch_size=2, + peak_lr=2e-4, + initial_step=25, + show_progress=False, + ) + ) + + assert len(all_chunks) == 2 + assert len(resumed_chunks) == 1 + # Resumed chunk should have the same LRs as the second full chunk + assert resumed_chunks[0].config.learning_rate == all_chunks[1].config.learning_rate + + +def test_create_sft_dataset_iterator_deterministic(): + """Test that create_sft_dataset_iterator is deterministic with the same seed.""" + trajs = _make_trajectories(50) + + chunks1 = list( + create_sft_dataset_iterator( + trajs, chunk_size=10, epochs=2, batch_size=2, seed=42, show_progress=False + ) + ) + chunks2 = list( + create_sft_dataset_iterator( + trajs, chunk_size=10, epochs=2, batch_size=2, seed=42, show_progress=False + ) + ) + + assert len(chunks1) == len(chunks2) + for c1, c2 in zip(chunks1, chunks2): + assert c1.config.learning_rate == c2.config.learning_rate + assert c1.step == c2.step + for t1, t2 in zip(c1.trajectories, c2.trajectories): + assert t1.messages_and_choices == t2.messages_and_choices + + +def test_create_sft_dataset_iterator_empty_input(): + """Test that empty trajectories yields no chunks.""" + chunks = list(create_sft_dataset_iterator([], chunk_size=10, show_progress=False)) + assert chunks == [] + + +def test_create_sft_dataset_iterator_single_chunk(): + """Test that chunk_size >= dataset produces one chunk with full schedule.""" + import math + + trajs = _make_trajectories(10) + + chunks = list( + create_sft_dataset_iterator( + trajs, + chunk_size=100, # larger than dataset + epochs=1, + batch_size=2, + peak_lr=1e-4, + seed=42, + show_progress=False, + ) + ) + + total_batches = math.ceil(10 / 2) + warmup_steps = int(total_batches * 0.1) + expected_lrs = create_lr_schedule( + total_steps=total_batches, + peak_lr=1e-4, + method="linear", + warmup_steps=warmup_steps, + ) + + assert len(chunks) == 1 + assert chunks[0].config.learning_rate == expected_lrs + assert chunks[0].config.batch_size == 2 + assert len(chunks[0].trajectories) == 10 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/uv.lock b/uv.lock index c357a83dc..315758e3c 100644 --- a/uv.lock +++ b/uv.lock @@ -12,6 +12,9 @@ resolution-markers = [ "python_full_version < '3.12' and sys_platform != 'linux'", ] +[manifest] +overrides = [{ name = "numpy", specifier = "<2" }] + [[package]] name = "abnf" version = "2.2.0" @@ -4172,50 +4175,26 @@ wheels = [ [[package]] name = "numpy" -version = "2.2.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/76/21/7d2a95e4bba9dc13d043ee156a356c0a8f0c6309dff6b21b4d71a073b8a8/numpy-2.2.6.tar.gz", hash = "sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd", size = 20276440, upload-time = "2025-05-17T22:38:04.611Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/da/a8/4f83e2aa666a9fbf56d6118faaaf5f1974d456b1823fda0a176eff722839/numpy-2.2.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f9f1adb22318e121c5c69a09142811a201ef17ab257a1e66ca3025065b7f53ae", size = 21176963, upload-time = "2025-05-17T21:31:19.36Z" }, - { url = "https://files.pythonhosted.org/packages/b3/2b/64e1affc7972decb74c9e29e5649fac940514910960ba25cd9af4488b66c/numpy-2.2.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c820a93b0255bc360f53eca31a0e676fd1101f673dda8da93454a12e23fc5f7a", size = 14406743, upload-time = "2025-05-17T21:31:41.087Z" }, - { url = "https://files.pythonhosted.org/packages/4a/9f/0121e375000b5e50ffdd8b25bf78d8e1a5aa4cca3f185d41265198c7b834/numpy-2.2.6-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3d70692235e759f260c3d837193090014aebdf026dfd167834bcba43e30c2a42", size = 5352616, upload-time = "2025-05-17T21:31:50.072Z" }, - { url = "https://files.pythonhosted.org/packages/31/0d/b48c405c91693635fbe2dcd7bc84a33a602add5f63286e024d3b6741411c/numpy-2.2.6-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:481b49095335f8eed42e39e8041327c05b0f6f4780488f61286ed3c01368d491", size = 6889579, upload-time = "2025-05-17T21:32:01.712Z" }, - { url = "https://files.pythonhosted.org/packages/52/b8/7f0554d49b565d0171eab6e99001846882000883998e7b7d9f0d98b1f934/numpy-2.2.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b64d8d4d17135e00c8e346e0a738deb17e754230d7e0810ac5012750bbd85a5a", size = 14312005, upload-time = "2025-05-17T21:32:23.332Z" }, - { url = "https://files.pythonhosted.org/packages/b3/dd/2238b898e51bd6d389b7389ffb20d7f4c10066d80351187ec8e303a5a475/numpy-2.2.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba10f8411898fc418a521833e014a77d3ca01c15b0c6cdcce6a0d2897e6dbbdf", size = 16821570, upload-time = "2025-05-17T21:32:47.991Z" }, - { url = "https://files.pythonhosted.org/packages/83/6c/44d0325722cf644f191042bf47eedad61c1e6df2432ed65cbe28509d404e/numpy-2.2.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:bd48227a919f1bafbdda0583705e547892342c26fb127219d60a5c36882609d1", size = 15818548, upload-time = "2025-05-17T21:33:11.728Z" }, - { url = "https://files.pythonhosted.org/packages/ae/9d/81e8216030ce66be25279098789b665d49ff19eef08bfa8cb96d4957f422/numpy-2.2.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9551a499bf125c1d4f9e250377c1ee2eddd02e01eac6644c080162c0c51778ab", size = 18620521, upload-time = "2025-05-17T21:33:39.139Z" }, - { url = "https://files.pythonhosted.org/packages/6a/fd/e19617b9530b031db51b0926eed5345ce8ddc669bb3bc0044b23e275ebe8/numpy-2.2.6-cp311-cp311-win32.whl", hash = "sha256:0678000bb9ac1475cd454c6b8c799206af8107e310843532b04d49649c717a47", size = 6525866, upload-time = "2025-05-17T21:33:50.273Z" }, - { url = "https://files.pythonhosted.org/packages/31/0a/f354fb7176b81747d870f7991dc763e157a934c717b67b58456bc63da3df/numpy-2.2.6-cp311-cp311-win_amd64.whl", hash = "sha256:e8213002e427c69c45a52bbd94163084025f533a55a59d6f9c5b820774ef3303", size = 12907455, upload-time = "2025-05-17T21:34:09.135Z" }, - { url = "https://files.pythonhosted.org/packages/82/5d/c00588b6cf18e1da539b45d3598d3557084990dcc4331960c15ee776ee41/numpy-2.2.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:41c5a21f4a04fa86436124d388f6ed60a9343a6f767fced1a8a71c3fbca038ff", size = 20875348, upload-time = "2025-05-17T21:34:39.648Z" }, - { url = "https://files.pythonhosted.org/packages/66/ee/560deadcdde6c2f90200450d5938f63a34b37e27ebff162810f716f6a230/numpy-2.2.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de749064336d37e340f640b05f24e9e3dd678c57318c7289d222a8a2f543e90c", size = 14119362, upload-time = "2025-05-17T21:35:01.241Z" }, - { url = "https://files.pythonhosted.org/packages/3c/65/4baa99f1c53b30adf0acd9a5519078871ddde8d2339dc5a7fde80d9d87da/numpy-2.2.6-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:894b3a42502226a1cac872f840030665f33326fc3dac8e57c607905773cdcde3", size = 5084103, upload-time = "2025-05-17T21:35:10.622Z" }, - { url = "https://files.pythonhosted.org/packages/cc/89/e5a34c071a0570cc40c9a54eb472d113eea6d002e9ae12bb3a8407fb912e/numpy-2.2.6-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:71594f7c51a18e728451bb50cc60a3ce4e6538822731b2933209a1f3614e9282", size = 6625382, upload-time = "2025-05-17T21:35:21.414Z" }, - { url = "https://files.pythonhosted.org/packages/f8/35/8c80729f1ff76b3921d5c9487c7ac3de9b2a103b1cd05e905b3090513510/numpy-2.2.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2618db89be1b4e05f7a1a847a9c1c0abd63e63a1607d892dd54668dd92faf87", size = 14018462, upload-time = "2025-05-17T21:35:42.174Z" }, - { url = "https://files.pythonhosted.org/packages/8c/3d/1e1db36cfd41f895d266b103df00ca5b3cbe965184df824dec5c08c6b803/numpy-2.2.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd83c01228a688733f1ded5201c678f0c53ecc1006ffbc404db9f7a899ac6249", size = 16527618, upload-time = "2025-05-17T21:36:06.711Z" }, - { url = "https://files.pythonhosted.org/packages/61/c6/03ed30992602c85aa3cd95b9070a514f8b3c33e31124694438d88809ae36/numpy-2.2.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:37c0ca431f82cd5fa716eca9506aefcabc247fb27ba69c5062a6d3ade8cf8f49", size = 15505511, upload-time = "2025-05-17T21:36:29.965Z" }, - { url = "https://files.pythonhosted.org/packages/b7/25/5761d832a81df431e260719ec45de696414266613c9ee268394dd5ad8236/numpy-2.2.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fe27749d33bb772c80dcd84ae7e8df2adc920ae8297400dabec45f0dedb3f6de", size = 18313783, upload-time = "2025-05-17T21:36:56.883Z" }, - { url = "https://files.pythonhosted.org/packages/57/0a/72d5a3527c5ebffcd47bde9162c39fae1f90138c961e5296491ce778e682/numpy-2.2.6-cp312-cp312-win32.whl", hash = "sha256:4eeaae00d789f66c7a25ac5f34b71a7035bb474e679f410e5e1a94deb24cf2d4", size = 6246506, upload-time = "2025-05-17T21:37:07.368Z" }, - { url = "https://files.pythonhosted.org/packages/36/fa/8c9210162ca1b88529ab76b41ba02d433fd54fecaf6feb70ef9f124683f1/numpy-2.2.6-cp312-cp312-win_amd64.whl", hash = "sha256:c1f9540be57940698ed329904db803cf7a402f3fc200bfe599334c9bd84a40b2", size = 12614190, upload-time = "2025-05-17T21:37:26.213Z" }, - { url = "https://files.pythonhosted.org/packages/f9/5c/6657823f4f594f72b5471f1db1ab12e26e890bb2e41897522d134d2a3e81/numpy-2.2.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0811bb762109d9708cca4d0b13c4f67146e3c3b7cf8d34018c722adb2d957c84", size = 20867828, upload-time = "2025-05-17T21:37:56.699Z" }, - { url = "https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:287cc3162b6f01463ccd86be154f284d0893d2b3ed7292439ea97eafa8170e0b", size = 14143006, upload-time = "2025-05-17T21:38:18.291Z" }, - { url = "https://files.pythonhosted.org/packages/4f/06/7e96c57d90bebdce9918412087fc22ca9851cceaf5567a45c1f404480e9e/numpy-2.2.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:f1372f041402e37e5e633e586f62aa53de2eac8d98cbfb822806ce4bbefcb74d", size = 5076765, upload-time = "2025-05-17T21:38:27.319Z" }, - { url = "https://files.pythonhosted.org/packages/73/ed/63d920c23b4289fdac96ddbdd6132e9427790977d5457cd132f18e76eae0/numpy-2.2.6-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:55a4d33fa519660d69614a9fad433be87e5252f4b03850642f88993f7b2ca566", size = 6617736, upload-time = "2025-05-17T21:38:38.141Z" }, - { url = "https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f92729c95468a2f4f15e9bb94c432a9229d0d50de67304399627a943201baa2f", size = 14010719, upload-time = "2025-05-17T21:38:58.433Z" }, - { url = "https://files.pythonhosted.org/packages/19/49/4df9123aafa7b539317bf6d342cb6d227e49f7a35b99c287a6109b13dd93/numpy-2.2.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bc23a79bfabc5d056d106f9befb8d50c31ced2fbc70eedb8155aec74a45798f", size = 16526072, upload-time = "2025-05-17T21:39:22.638Z" }, - { url = "https://files.pythonhosted.org/packages/b2/6c/04b5f47f4f32f7c2b0e7260442a8cbcf8168b0e1a41ff1495da42f42a14f/numpy-2.2.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e3143e4451880bed956e706a3220b4e5cf6172ef05fcc397f6f36a550b1dd868", size = 15503213, upload-time = "2025-05-17T21:39:45.865Z" }, - { url = "https://files.pythonhosted.org/packages/17/0a/5cd92e352c1307640d5b6fec1b2ffb06cd0dabe7d7b8227f97933d378422/numpy-2.2.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b4f13750ce79751586ae2eb824ba7e1e8dba64784086c98cdbbcc6a42112ce0d", size = 18316632, upload-time = "2025-05-17T21:40:13.331Z" }, - { url = "https://files.pythonhosted.org/packages/f0/3b/5cba2b1d88760ef86596ad0f3d484b1cbff7c115ae2429678465057c5155/numpy-2.2.6-cp313-cp313-win32.whl", hash = "sha256:5beb72339d9d4fa36522fc63802f469b13cdbe4fdab4a288f0c441b74272ebfd", size = 6244532, upload-time = "2025-05-17T21:43:46.099Z" }, - { url = "https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl", hash = "sha256:b0544343a702fa80c95ad5d3d608ea3599dd54d4632df855e4c8d24eb6ecfa1c", size = 12610885, upload-time = "2025-05-17T21:44:05.145Z" }, - { url = "https://files.pythonhosted.org/packages/6b/9e/4bf918b818e516322db999ac25d00c75788ddfd2d2ade4fa66f1f38097e1/numpy-2.2.6-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0bca768cd85ae743b2affdc762d617eddf3bcf8724435498a1e80132d04879e6", size = 20963467, upload-time = "2025-05-17T21:40:44Z" }, - { url = "https://files.pythonhosted.org/packages/61/66/d2de6b291507517ff2e438e13ff7b1e2cdbdb7cb40b3ed475377aece69f9/numpy-2.2.6-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:fc0c5673685c508a142ca65209b4e79ed6740a4ed6b2267dbba90f34b0b3cfda", size = 14225144, upload-time = "2025-05-17T21:41:05.695Z" }, - { url = "https://files.pythonhosted.org/packages/e4/25/480387655407ead912e28ba3a820bc69af9adf13bcbe40b299d454ec011f/numpy-2.2.6-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:5bd4fc3ac8926b3819797a7c0e2631eb889b4118a9898c84f585a54d475b7e40", size = 5200217, upload-time = "2025-05-17T21:41:15.903Z" }, - { url = "https://files.pythonhosted.org/packages/aa/4a/6e313b5108f53dcbf3aca0c0f3e9c92f4c10ce57a0a721851f9785872895/numpy-2.2.6-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:fee4236c876c4e8369388054d02d0e9bb84821feb1a64dd59e137e6511a551f8", size = 6712014, upload-time = "2025-05-17T21:41:27.321Z" }, - { url = "https://files.pythonhosted.org/packages/b7/30/172c2d5c4be71fdf476e9de553443cf8e25feddbe185e0bd88b096915bcc/numpy-2.2.6-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e1dda9c7e08dc141e0247a5b8f49cf05984955246a327d4c48bda16821947b2f", size = 14077935, upload-time = "2025-05-17T21:41:49.738Z" }, - { url = "https://files.pythonhosted.org/packages/12/fb/9e743f8d4e4d3c710902cf87af3512082ae3d43b945d5d16563f26ec251d/numpy-2.2.6-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f447e6acb680fd307f40d3da4852208af94afdfab89cf850986c3ca00562f4fa", size = 16600122, upload-time = "2025-05-17T21:42:14.046Z" }, - { url = "https://files.pythonhosted.org/packages/12/75/ee20da0e58d3a66f204f38916757e01e33a9737d0b22373b3eb5a27358f9/numpy-2.2.6-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:389d771b1623ec92636b0786bc4ae56abafad4a4c513d36a55dce14bd9ce8571", size = 15586143, upload-time = "2025-05-17T21:42:37.464Z" }, - { url = "https://files.pythonhosted.org/packages/76/95/bef5b37f29fc5e739947e9ce5179ad402875633308504a52d188302319c8/numpy-2.2.6-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8e9ace4a37db23421249ed236fdcdd457d671e25146786dfc96835cd951aa7c1", size = 18385260, upload-time = "2025-05-17T21:43:05.189Z" }, - { url = "https://files.pythonhosted.org/packages/09/04/f2f83279d287407cf36a7a8053a5abe7be3622a4363337338f2585e4afda/numpy-2.2.6-cp313-cp313t-win32.whl", hash = "sha256:038613e9fb8c72b0a41f025a7e4c3f0b7a1b5d768ece4796b674c8f3fe13efff", size = 6377225, upload-time = "2025-05-17T21:43:16.254Z" }, - { url = "https://files.pythonhosted.org/packages/67/0e/35082d13c09c02c011cf21570543d202ad929d961c02a147493cb0c2bdf5/numpy-2.2.6-cp313-cp313t-win_amd64.whl", hash = "sha256:6031dd6dfecc0cf9f668681a37648373bddd6421fff6c66ec1624eed0180ee06", size = 12771374, upload-time = "2025-05-17T21:43:35.479Z" }, +version = "1.26.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129, upload-time = "2024-02-06T00:26:44.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/57/baae43d14fe163fa0e4c47f307b6b2511ab8d7d30177c491960504252053/numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71", size = 20630554, upload-time = "2024-02-05T23:51:50.149Z" }, + { url = "https://files.pythonhosted.org/packages/1a/2e/151484f49fd03944c4a3ad9c418ed193cfd02724e138ac8a9505d056c582/numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef", size = 13997127, upload-time = "2024-02-05T23:52:15.314Z" }, + { url = "https://files.pythonhosted.org/packages/79/ae/7e5b85136806f9dadf4878bf73cf223fe5c2636818ba3ab1c585d0403164/numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e", size = 14222994, upload-time = "2024-02-05T23:52:47.569Z" }, + { url = "https://files.pythonhosted.org/packages/3a/d0/edc009c27b406c4f9cbc79274d6e46d634d139075492ad055e3d68445925/numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5", size = 18252005, upload-time = "2024-02-05T23:53:15.637Z" }, + { url = "https://files.pythonhosted.org/packages/09/bf/2b1aaf8f525f2923ff6cfcf134ae5e750e279ac65ebf386c75a0cf6da06a/numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a", size = 13885297, upload-time = "2024-02-05T23:53:42.16Z" }, + { url = "https://files.pythonhosted.org/packages/df/a0/4e0f14d847cfc2a633a1c8621d00724f3206cfeddeb66d35698c4e2cf3d2/numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a", size = 18093567, upload-time = "2024-02-05T23:54:11.696Z" }, + { url = "https://files.pythonhosted.org/packages/d2/b7/a734c733286e10a7f1a8ad1ae8c90f2d33bf604a96548e0a4a3a6739b468/numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20", size = 5968812, upload-time = "2024-02-05T23:54:26.453Z" }, + { url = "https://files.pythonhosted.org/packages/3f/6b/5610004206cf7f8e7ad91c5a85a8c71b2f2f8051a0c0c4d5916b76d6cbb2/numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2", size = 15811913, upload-time = "2024-02-05T23:54:53.933Z" }, + { url = "https://files.pythonhosted.org/packages/95/12/8f2020a8e8b8383ac0177dc9570aad031a3beb12e38847f7129bacd96228/numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218", size = 20335901, upload-time = "2024-02-05T23:55:32.801Z" }, + { url = "https://files.pythonhosted.org/packages/75/5b/ca6c8bd14007e5ca171c7c03102d17b4f4e0ceb53957e8c44343a9546dcc/numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b", size = 13685868, upload-time = "2024-02-05T23:55:56.28Z" }, + { url = "https://files.pythonhosted.org/packages/79/f8/97f10e6755e2a7d027ca783f63044d5b1bc1ae7acb12afe6a9b4286eac17/numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b", size = 13925109, upload-time = "2024-02-05T23:56:20.368Z" }, + { url = "https://files.pythonhosted.org/packages/0f/50/de23fde84e45f5c4fda2488c759b69990fd4512387a8632860f3ac9cd225/numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed", size = 17950613, upload-time = "2024-02-05T23:56:56.054Z" }, + { url = "https://files.pythonhosted.org/packages/4c/0c/9c603826b6465e82591e05ca230dfc13376da512b25ccd0894709b054ed0/numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a", size = 13572172, upload-time = "2024-02-05T23:57:21.56Z" }, + { url = "https://files.pythonhosted.org/packages/76/8c/2ba3902e1a0fc1c74962ea9bb33a534bb05984ad7ff9515bf8d07527cadd/numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0", size = 17786643, upload-time = "2024-02-05T23:57:56.585Z" }, + { url = "https://files.pythonhosted.org/packages/28/4a/46d9e65106879492374999e76eb85f87b15328e06bd1550668f79f7b18c6/numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110", size = 5677803, upload-time = "2024-02-05T23:58:08.963Z" }, + { url = "https://files.pythonhosted.org/packages/16/2e/86f24451c2d530c88daf997cb8d6ac622c1d40d19f5a031ed68a4b73a374/numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818", size = 15517754, upload-time = "2024-02-05T23:58:36.364Z" }, ] [[package]] @@ -4579,8 +4558,8 @@ requires-dist = [ { name = "unsloth-zoo", marker = "extra == 'backend'", specifier = "==2025.12.7" }, { name = "uvicorn", marker = "extra == 'tinker'", specifier = ">=0.35.0" }, { name = "vllm", marker = "sys_platform == 'linux' and extra == 'backend'", specifier = "==0.15.1" }, - { name = "wandb", marker = "extra == 'backend'", specifier = "==0.23.1" }, - { name = "weave", specifier = ">=0.52.23" }, + { name = "wandb", marker = "extra == 'backend'", specifier = "==0.24.0" }, + { name = "weave", specifier = ">=0.52.24" }, ] provides-extras = ["plotting", "backend", "langgraph", "tinker"] @@ -8289,7 +8268,7 @@ wheels = [ [[package]] name = "wandb" -version = "0.23.1" +version = "0.24.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -8304,17 +8283,17 @@ dependencies = [ { name = "sentry-sdk" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0a/cc/770ae3aa7ae44f6792f7ecb81c14c0e38b672deb35235719bb1006519487/wandb-0.23.1.tar.gz", hash = "sha256:f6fb1e3717949b29675a69359de0eeb01e67d3360d581947d5b3f98c273567d6", size = 44298053, upload-time = "2025-12-03T02:25:10.79Z" } +sdist = { url = "https://files.pythonhosted.org/packages/27/7e/aad6e943012ea4d88f3a037f1a5a7c6898263c60fbef8c9cdb95a8ff9fd9/wandb-0.24.0.tar.gz", hash = "sha256:4715a243b3d460b6434b9562e935dfd9dfdf5d6e428cfb4c3e7ce4fd44460ab3", size = 44197947, upload-time = "2026-01-13T22:59:59.767Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/12/0b/c3d7053dfd93fd259a63c7818d9c4ac2ba0642ff8dc8db98662ea0cf9cc0/wandb-0.23.1-py3-none-macosx_12_0_arm64.whl", hash = "sha256:358e15471d19b7d73fc464e37371c19d44d39e433252ac24df107aff993a286b", size = 21527293, upload-time = "2025-12-03T02:24:48.011Z" }, - { url = "https://files.pythonhosted.org/packages/ee/9f/059420fa0cb6c511dc5c5a50184122b6aca7b178cb2aa210139e354020da/wandb-0.23.1-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:110304407f4b38f163bdd50ed5c5225365e4df3092f13089c30171a75257b575", size = 22745926, upload-time = "2025-12-03T02:24:50.519Z" }, - { url = "https://files.pythonhosted.org/packages/96/b6/fd465827c14c64d056d30b4c9fcf4dac889a6969dba64489a88fc4ffa333/wandb-0.23.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:6cc984cf85feb2f8ee0451d76bc9fb7f39da94956bb8183e30d26284cf203b65", size = 21212973, upload-time = "2025-12-03T02:24:52.828Z" }, - { url = "https://files.pythonhosted.org/packages/5c/ee/9a8bb9a39cc1f09c3060456cc79565110226dc4099a719af5c63432da21d/wandb-0.23.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:67431cd3168d79fdb803e503bd669c577872ffd5dadfa86de733b3274b93088e", size = 22887885, upload-time = "2025-12-03T02:24:55.281Z" }, - { url = "https://files.pythonhosted.org/packages/6d/4d/8d9e75add529142e037b05819cb3ab1005679272950128d69d218b7e5b2e/wandb-0.23.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:07be70c0baa97ea25fadc4a9d0097f7371eef6dcacc5ceb525c82491a31e9244", size = 21250967, upload-time = "2025-12-03T02:24:57.603Z" }, - { url = "https://files.pythonhosted.org/packages/97/72/0b35cddc4e4168f03c759b96d9f671ad18aec8bdfdd84adfea7ecb3f5701/wandb-0.23.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:216c95b08e0a2ec6a6008373b056d597573d565e30b43a7a93c35a171485ee26", size = 22988382, upload-time = "2025-12-03T02:25:00.518Z" }, - { url = "https://files.pythonhosted.org/packages/c0/6d/e78093d49d68afb26f5261a70fc7877c34c114af5c2ee0ab3b1af85f5e76/wandb-0.23.1-py3-none-win32.whl", hash = "sha256:fb5cf0f85692f758a5c36ab65fea96a1284126de64e836610f92ddbb26df5ded", size = 22150756, upload-time = "2025-12-03T02:25:02.734Z" }, - { url = "https://files.pythonhosted.org/packages/05/27/4f13454b44c9eceaac3d6e4e4efa2230b6712d613ff9bf7df010eef4fd18/wandb-0.23.1-py3-none-win_amd64.whl", hash = "sha256:21c8c56e436eb707b7d54f705652e030d48e5cfcba24cf953823eb652e30e714", size = 22150760, upload-time = "2025-12-03T02:25:05.106Z" }, - { url = "https://files.pythonhosted.org/packages/30/20/6c091d451e2a07689bfbfaeb7592d488011420e721de170884fedd68c644/wandb-0.23.1-py3-none-win_arm64.whl", hash = "sha256:8aee7f3bb573f2c0acf860f497ca9c684f9b35f2ca51011ba65af3d4592b77c1", size = 20137463, upload-time = "2025-12-03T02:25:08.317Z" }, + { url = "https://files.pythonhosted.org/packages/5f/8a/efec186dcc5dcf3c806040e3f33e58997878b2d30b87aa02b26f046858b6/wandb-0.24.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:aa9777398ff4b0f04c41359f7d1b95b5d656cb12c37c63903666799212e50299", size = 21464901, upload-time = "2026-01-13T22:59:31.86Z" }, + { url = "https://files.pythonhosted.org/packages/ed/84/fadf0d5f1d86c3ba662d2b33a15d2b1f08ff1e4e196c77e455f028b0fda2/wandb-0.24.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:0423fbd58c3926949724feae8aab89d20c68846f9f4f596b80f9ffe1fc298130", size = 22697817, upload-time = "2026-01-13T22:59:35.267Z" }, + { url = "https://files.pythonhosted.org/packages/6e/5f/e3124e68d02b30c62856175ce714e07904730be06eecb00f66bb1a59aacf/wandb-0.24.0-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:2b25fc0c123daac97ed32912ac55642c65013cc6e3a898e88ca2d917fc8eadc0", size = 21118798, upload-time = "2026-01-13T22:59:38.453Z" }, + { url = "https://files.pythonhosted.org/packages/22/a1/8d68a914c030e897c306c876d47c73aa5d9ca72be608971290d3a5749570/wandb-0.24.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:9485344b4667944b5b77294185bae8469cfa4074869bec0e74f54f8492234cc2", size = 22849954, upload-time = "2026-01-13T22:59:41.265Z" }, + { url = "https://files.pythonhosted.org/packages/e9/f8/3e68841a4282a4fb6a8935534e6064acc6c9708e8fb76953ec73bbc72a5e/wandb-0.24.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:51b2b9a9d7d6b35640f12a46a48814fd4516807ad44f586b819ed6560f8de1fd", size = 21160339, upload-time = "2026-01-13T22:59:43.967Z" }, + { url = "https://files.pythonhosted.org/packages/16/e5/d851868ce5b4b437a7cc90405979cd83809790e4e2a2f1e454f63f116e52/wandb-0.24.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:11f7e7841f31eff82c82a677988889ad3aa684c6de61ff82145333b5214ec860", size = 22936978, upload-time = "2026-01-13T22:59:46.911Z" }, + { url = "https://files.pythonhosted.org/packages/d2/34/43b7f18870051047ce6fe18e7eb24ba7ebdc71663a8f1c58e31e855eb8ac/wandb-0.24.0-py3-none-win32.whl", hash = "sha256:42af348998b00d4309ae790c5374040ac6cc353ab21567f4e29c98c9376dee8e", size = 22118243, upload-time = "2026-01-13T22:59:49.555Z" }, + { url = "https://files.pythonhosted.org/packages/a1/92/909c81173cf1399111f57f9ca5399a8f165607b024e406e080178c878f70/wandb-0.24.0-py3-none-win_amd64.whl", hash = "sha256:32604eddcd362e1ed4a2e2ce5f3a239369c4a193af223f3e66603481ac91f336", size = 22118246, upload-time = "2026-01-13T22:59:52.126Z" }, + { url = "https://files.pythonhosted.org/packages/87/85/a845aefd9c2285f98261fa6ffa0a14466366c1ac106d35bc84b654c0ad7f/wandb-0.24.0-py3-none-win_arm64.whl", hash = "sha256:e0f2367552abfca21b0f3a03405fbf48f1e14de9846e70f73c6af5da57afd8ef", size = 20077678, upload-time = "2026-01-13T22:59:56.112Z" }, ] [[package]] @@ -8454,7 +8433,7 @@ wheels = [ [[package]] name = "weave" -version = "0.52.23" +version = "0.52.24" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -8469,9 +8448,9 @@ dependencies = [ { name = "tzdata", marker = "sys_platform == 'win32'" }, { name = "wandb" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a2/3c/119b64e092218f7c37f9ca65bfda8ef856be6c03687e5701fa67fd16f3df/weave-0.52.23.tar.gz", hash = "sha256:ad4f37cc901cb93a000faedbe1313509e97c2e7a18cbaada574fef820741ff32", size = 647645, upload-time = "2026-01-08T18:23:46.71Z" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a9/479a4ebd918bf8d08ea1d8daf1a40ace4c6940f86472700036436c6acc78/weave-0.52.24.tar.gz", hash = "sha256:1d756c4b35054d8efba0300eb2a595f9c63d96601f1a3f37918e620742c13481", size = 650113, upload-time = "2026-01-15T20:46:41.771Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/06/489587dd6f5d06f19535ffbf6cd98e061eef3b50583d8dc40d79271c470b/weave-0.52.23-py3-none-any.whl", hash = "sha256:debe9eac5bdec857dc2507401f075b1b1f3199eccc9ef29f106747c13b23966c", size = 810639, upload-time = "2026-01-08T18:23:44.593Z" }, + { url = "https://files.pythonhosted.org/packages/c4/b1/b900fb5e81bbb143ef69ea49b2a2024357f6423511b24d14ab65dbe80272/weave-0.52.24-py3-none-any.whl", hash = "sha256:afe79bb5326aaf22751f2bae20878354062794c3577ab7a28c67512b784d8992", size = 813312, upload-time = "2026-01-15T20:46:39.725Z" }, ] [[package]]