diff --git a/pyproject.toml b/pyproject.toml
index 0e1fcf8c..a7de3c86 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,7 +53,7 @@ trackio = [
     "trackio<1.0.0",
 ]
 verifiers = [
-    "verifiers",
+    "verifiers>=0.1.8.post0",
     "openai",
 ]
 all = [
diff --git a/tinker_cookbook/recipes/verifiers_rl/tinker_openai.py b/tinker_cookbook/recipes/verifiers_rl/tinker_openai.py
index 153fe075..1c213b17 100644
--- a/tinker_cookbook/recipes/verifiers_rl/tinker_openai.py
+++ b/tinker_cookbook/recipes/verifiers_rl/tinker_openai.py
@@ -11,49 +11,21 @@ from __future__ import annotations
 
 import time
-from typing import Any, Callable, Dict, List, Optional, overload, Literal
+from typing import Any, Dict, List, Literal, overload
 
 import tinker
-from openai.types.chat.chat_completion import ChatCompletion
-from openai.types.completion import Completion
 from openai import AsyncOpenAI
+from openai._streaming import AsyncStream
 from openai.resources.chat import AsyncChat as OpenAIAsyncChat
 from openai.resources.chat.completions import AsyncCompletions as OpenAIAsyncChatCompletions
 from openai.resources.completions import AsyncCompletions as OpenAIAsyncCompletions
-from openai._streaming import AsyncStream
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.completion import Completion
 
 from tinker_cookbook import renderers
 from tinker_cookbook.tokenizer_utils import Tokenizer
 
-GenerationHook = Callable[
-    [List[renderers.Message], tinker.ModelInput, List[int], List[float]], None
-]
-
-
-def convert_oai_messages_to_renderer_messages(
-    messages: List[Dict[str, Any]],
-) -> List[renderers.Message]:
-    out: List[renderers.Message] = []
-    for m in messages:
-        role = str(m.get("role", "user"))
-        content = m.get("content", "")
-        # extract text from list of content parts if necessary
-        if isinstance(content, list):
-            text_parts: List[str] = []
-            for part in content:
-                if isinstance(part, dict):
-                    if "text" in part:
-                        text_parts.append(str(part["text"]))
-                elif isinstance(part, str):
-                    text_parts.append(part)
-            content = "".join(text_parts)
-        else:
-            content = str(content)
-        out.append(renderers.Message(role=role, content=content))
-    return out
-
-
 class TinkerAsyncOpenAIClient(AsyncOpenAI):
     """
     OpenAI-compatible async client that routes calls to a Tinker SamplingClient.
@@ -69,10 +41,6 @@ def __init__(
         self.sampling_client = sampling_client
         self.renderer = renderer
         self.tokenizer = tokenizer
-        self.hook: Optional[GenerationHook] = None
-
-    def set_generation_hook(self, hook: Optional[GenerationHook]) -> None:
-        self.hook = hook
 
     def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None:
         self.sampling_client = sampling_client
@@ -106,16 +74,18 @@ async def create(self, *args: Any, stream: bool, **kwargs: Any) -> ChatCompletio
     async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStream[Any]:
         model = kwargs.get("model", "tinker")
         messages = kwargs.get("messages", [])
+        if kwargs.get("tools"):
+            raise NotImplementedError("Tool calling is not yet supported by this model's renderer.")
         if kwargs.get("stream", False):
             raise ValueError("stream=True not supported by TinkerAsyncOpenAIClient")
 
         sampling_args = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools")}
 
-        # prepare prompt
-        conv_messages = convert_oai_messages_to_renderer_messages(messages)
         stop = sampling_args.get("stop", self._parent.renderer.get_stop_sequences())
         max_tokens = sampling_args.get("max_tokens") or sampling_args.get("max_completion_tokens")
-        model_input = self._parent.renderer.build_generation_prompt(conv_messages)
+        model_input = self._parent.renderer.build_generation_prompt(messages)
+        prompt_token_ids: List[int] = model_input.to_ints()
+
         sample = await self._parent.sampling_client.sample_async(
             prompt=model_input,
             num_samples=1,
@@ -128,15 +98,12 @@ async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStrea
             ),
         )
         seq = sample.sequences[0]
-        tokens: List[int] = seq.tokens
-        logprobs: List[float] = seq.logprobs or [0.0] * len(tokens)
-
-        if self._parent.hook is not None:
-            self._parent.hook(conv_messages, model_input, tokens, logprobs)
+        completion_token_ids: List[int] = seq.tokens
+        logprobs: List[float] = seq.logprobs or [0.0] * len(completion_token_ids)
 
-        # build ChatCompletion via pydantic validation using renderer parsing
-        assistant_message, parse_success = self._parent.renderer.parse_response(tokens)
-        content_text = assistant_message["content"]
+        assistant_message, parse_success = self._parent.renderer.parse_response(
+            completion_token_ids
+        )
         finish_reason = "stop" if parse_success else "length"
         response_dict: Dict[str, Any] = {
             "id": "tinker-chatcmpl",
@@ -146,23 +113,28 @@ async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStrea
             "choices": [
                 {
                     "index": 0,
-                    "message": {"role": "assistant", "content": content_text},
+                    "message": assistant_message,
                     "finish_reason": finish_reason,
                     "logprobs": {
                         "content": [
-                            {"token": f"token_id:{tid}", "logprob": float(lp), "top_logprobs": []}
-                            for tid, lp in zip(tokens, logprobs)
+                            {"token": f"token_id:{tid}", "logprob": lp, "top_logprobs": []}
+                            for tid, lp in zip(completion_token_ids, logprobs)
                         ]
                     },
                 }
             ],
             "usage": {
-                "prompt_tokens": model_input.length,
-                "completion_tokens": len(tokens),
-                "total_tokens": model_input.length + len(tokens),
+                "prompt_tokens": len(prompt_token_ids),
+                "completion_tokens": len(completion_token_ids),
+                "total_tokens": len(prompt_token_ids) + len(completion_token_ids),
             },
         }
-        return ChatCompletion.model_validate(response_dict)
+        response = ChatCompletion.model_validate(response_dict)
+
+        setattr(response, "prompt_token_ids", prompt_token_ids)
+        setattr(response.choices[0], "token_ids", completion_token_ids)
+
+        return response
 
 
 class TinkerCompletions(OpenAIAsyncCompletions):
@@ -190,10 +162,9 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
         prompt = kwargs.get("prompt", "")
         sampling_args = {k: v for k, v in kwargs.items() if k not in ("model", "prompt")}
 
-        # Completion-mode: render prompt directly as text chunk
-        model_input = tinker.ModelInput.from_ints(
-            self._parent.tokenizer.encode(prompt, add_special_tokens=True)
-        )
+        prompt_token_ids: List[int] = self._parent.tokenizer.encode(prompt, add_special_tokens=True)
+        model_input = tinker.ModelInput.from_ints(prompt_token_ids)
+
         sample = await self._parent.sampling_client.sample_async(
             prompt=model_input,
             num_samples=1,
@@ -205,11 +176,11 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
             ),
         )
         seq = sample.sequences[0]
-        tokens: List[int] = seq.tokens
-        logprobs: List[float] = seq.logprobs or [0.0] * len(tokens)
+        completion_token_ids: List[int] = seq.tokens
+        logprobs: List[float] = seq.logprobs or [0.0] * len(completion_token_ids)
 
-        text = self._parent.tokenizer.decode(tokens)
-        tokens_str = [f"token_id:{tid}" for tid in tokens]
+        text = self._parent.tokenizer.decode(completion_token_ids)
+        tokens_str = [f"token_id:{tid}" for tid in completion_token_ids]
         response_dict: Dict[str, Any] = {
             "id": "tinker-cmpl",
             "object": "text_completion",
@@ -222,20 +193,24 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
                     "finish_reason": "stop",
                     "logprobs": {
                         "tokens": tokens_str,
-                        "token_logprobs": [float(lp) for lp in logprobs],
+                        "token_logprobs": logprobs,
                     },
                 }
             ],
             "usage": {
-                "prompt_tokens": model_input.length,
-                "completion_tokens": len(tokens),
-                "total_tokens": model_input.length + len(tokens),
+                "prompt_tokens": len(prompt_token_ids),
+                "completion_tokens": len(completion_token_ids),
+                "total_tokens": len(prompt_token_ids) + len(completion_token_ids),
             },
         }
-        final = Completion.model_validate(response_dict)
+        response = Completion.model_validate(response_dict)
+
+        setattr(response.choices[0], "prompt_token_ids", prompt_token_ids)
+        setattr(response.choices[0], "token_ids", completion_token_ids)
+
         if stream:
-            return TinkerAsyncCompletionStream(final)
-        return final
+            return TinkerAsyncCompletionStream(response)
+        return response
 
 
 class TinkerAsyncChat(OpenAIAsyncChat):
diff --git a/tinker_cookbook/recipes/verifiers_rl/train.py b/tinker_cookbook/recipes/verifiers_rl/train.py
index d1c78162..c49212fe 100644
--- a/tinker_cookbook/recipes/verifiers_rl/train.py
+++ b/tinker_cookbook/recipes/verifiers_rl/train.py
@@ -1,25 +1,25 @@
 from __future__ import annotations
 
 import asyncio
+import json
 import logging
-from typing import Any, List
 from datetime import datetime
+from typing import Any, cast
 
 import chz
-import json
-import tinker
-import verifiers as vf
 from verifiers.utils.async_utils import maybe_semaphore
+
 from tinker_cookbook import cli_utils, model_info, renderers
-from tinker_cookbook.completers import TokensWithLogprobs, TokenCompleter, TinkerTokenCompleter
+from tinker_cookbook.completers import TinkerTokenCompleter, TokenCompleter
 from tinker_cookbook.recipes.verifiers_rl.tinker_openai import TinkerAsyncOpenAIClient
-from tinker_cookbook.rl import train
-from tinker_cookbook.rl.types import EnvGroupBuilder, Trajectory, Transition, TrajectoryGroup
-from tinker_cookbook.tokenizer_utils import Tokenizer, get_tokenizer
 from tinker_cookbook.recipes.verifiers_rl.verifiers_env import (
     VerifiersEnvGroupBuilder,
     VerifiersRLDatasetBuilder,
+    convert_states_to_trajectory_group,
 )
+from tinker_cookbook.rl import train
+from tinker_cookbook.rl.types import EnvGroupBuilder, TrajectoryGroup
+from tinker_cookbook.tokenizer_utils import Tokenizer, get_tokenizer
 
 logger = logging.getLogger(__name__)
@@ -42,7 +42,10 @@ class CLIConfig:
     num_substeps: int = 1
     learning_rate: float = 1e-5
     max_tokens: int = 512
+    temperature: float = 1.0
     kl_penalty_coef: float = 0.0
+    max_concurrent_generation: int = -1
+    max_concurrent_scoring: int = -1
 
     # logging configuration
     eval_every: int = 0
@@ -61,27 +64,19 @@ async def cli_main(cli_config: CLIConfig, env: Any | None):
         f"_lr{cli_config.learning_rate}_rank{cli_config.lora_rank}_{date_and_time}"
     )
 
-    if cli_config.log_path is not None:
-        log_path = cli_config.log_path
-    else:
-        log_path = f"/tmp/tinker-examples/verifiers_rl/{run_name}"
-
+    log_path = cli_config.log_path or f"/tmp/tinker-examples/verifiers_rl/{run_name}"
    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
 
-    # load verifiers environment (must be installed; `prime env install user/env-id`)
     env_args = json.loads(cli_config.vf_env_args) if cli_config.vf_env_args else {}
-    vf_env = vf.load_environment(cli_config.vf_env_id, **env_args)
 
-    # global objects shared across rollout groups
+    shared_client: TinkerAsyncOpenAIClient | None = None
     shared_renderer: renderers.Renderer | None = None
     local_tokenizer: Tokenizer | None = None
 
     async def custom_do_group_rollout(
         builder: EnvGroupBuilder, policy: TokenCompleter
     ) -> TrajectoryGroup:
-        assert isinstance(builder, VerifiersEnvGroupBuilder)
-        assert isinstance(policy, TinkerTokenCompleter)
-        nonlocal shared_renderer, local_tokenizer
+        nonlocal shared_client, shared_renderer, local_tokenizer
 
         # initialize tokenizer and renderer lazily
         if local_tokenizer is None:
@@ -89,85 +84,52 @@ async def custom_do_group_rollout(
         if shared_renderer is None:
             renderer_name = model_info.get_recommended_renderer_name(cli_config.model_name)
             shared_renderer = renderers.get_renderer(renderer_name, local_tokenizer)
-        sampling_client = policy.sampling_client
-
-        async def run_one_rollout() -> tuple[Trajectory, float, dict[str, float | int]]:
-            recorded: List[
-                tuple[list[renderers.Message], tinker.ModelInput, list[int], list[float]]
-            ] = []
-
-            def hook(messages, model_input, tokens, logprobs):
-                recorded.append((list(messages), model_input, list(tokens), list(logprobs)))
-
-            # create per-rollout client for hook
-            assert shared_renderer is not None and local_tokenizer is not None
-            local_client = TinkerAsyncOpenAIClient(
+        sampling_client = cast(TinkerTokenCompleter, policy).sampling_client
+        if shared_client is None:
+            shared_client = TinkerAsyncOpenAIClient(
                 sampling_client, shared_renderer, local_tokenizer
             )
-            local_client.set_generation_hook(hook)
-
-            rollout_input: vf.RolloutInput = {
-                "prompt": builder.prompt,
-                "answer": builder.answer,
-                "task": builder.task,
-                "info": builder.info,
-                "example_id": 0,
-            }
-            state = await builder.vf_env.rollout(
-                input=rollout_input,
-                client=local_client,
-                model="tinker",
-                sampling_args={},
-            )
+        else:
+            shared_client.set_sampling_client(sampling_client)
 
-            score_sem = await maybe_semaphore(None)
-            await builder.vf_env.rubric.score_rollout(
-                state=state,
-                score_sem=score_sem,
-            )
-            rs: vf.RolloutScore = {"reward": state["reward"], "metrics": state.get("metrics", {})}
-
-            transitions: List[Transition] = []
-            for _msgs, model_input, tokens, logprobs in recorded:
-                transitions.append(
-                    Transition(
-                        ob=model_input,
-                        ac=TokensWithLogprobs(tokens=tokens, maybe_logprobs=logprobs),
-                        reward=0.0,
-                        episode_done=False,
-                        metrics={},
-                    )
-                )
-            if transitions:
-                transitions[-1] = Transition(
-                    ob=transitions[-1].ob,
-                    ac=transitions[-1].ac,
-                    reward=0.0,
-                    episode_done=True,
-                    metrics=transitions[-1].metrics,
-                )
-            traj = Trajectory(transitions=transitions, final_ob=tinker.ModelInput.empty())
-            return traj, float(rs["reward"]), dict(rs["metrics"])
-
-        results = await asyncio.gather(*[run_one_rollout() for _ in range(cli_config.group_size)])
-        trajectories_G = [t for (t, _r, _m) in results]
-        final_rewards_G = [r for (_t, r, _m) in results]
-        metrics_G = [m for (_t, _r, m) in results]
-        return TrajectoryGroup(trajectories_G, final_rewards_G, metrics_G)
+        vf_builder = cast(VerifiersEnvGroupBuilder, builder)
+        rollout_inputs = vf_builder.get_rollout_inputs(cli_config.group_size)
+
+        gen_sem = await maybe_semaphore(cli_config.max_concurrent_generation)
+        score_sem = await maybe_semaphore(cli_config.max_concurrent_scoring)
+
+        states = await vf_builder.vf_env.run_group(
+            group_inputs=rollout_inputs,
+            client=shared_client,
+            model="tinker",
+            gen_sampling_args={
+                "max_tokens": cli_config.max_tokens,
+                "temperature": cli_config.temperature,
+            },
+            gen_sem=gen_sem,
+            score_sem=score_sem,
+        )
+
+        return convert_states_to_trajectory_group(states)
 
     # override do_group_rollout function inside rl.train
     train.do_group_rollout = custom_do_group_rollout
 
+    dataset_builder = VerifiersRLDatasetBuilder(
+        vf_env_id=cli_config.vf_env_id,
+        vf_env_args=env_args,
+        groups_per_batch=cli_config.groups_per_batch,
+        dataset_n=cli_config.dataset_n,
+        dataset_seed=cli_config.dataset_seed,
+    )
+
     cfg = train.Config(
         learning_rate=cli_config.learning_rate,
-        dataset_builder=VerifiersRLDatasetBuilder(
-            vf_env=vf_env,
-            groups_per_batch=cli_config.groups_per_batch,
-            dataset_n=cli_config.dataset_n,
-            dataset_seed=cli_config.dataset_seed,
-        ),
+        dataset_builder=dataset_builder,
         model_name=cli_config.model_name,
         max_tokens=cli_config.max_tokens,
+        temperature=cli_config.temperature,
         lora_rank=cli_config.lora_rank,
         kl_penalty_coef=cli_config.kl_penalty_coef,
         num_substeps=cli_config.num_substeps,
diff --git a/tinker_cookbook/recipes/verifiers_rl/verifiers_env.py b/tinker_cookbook/recipes/verifiers_rl/verifiers_env.py
index 8fda9b53..86371204 100644
--- a/tinker_cookbook/recipes/verifiers_rl/verifiers_env.py
+++ b/tinker_cookbook/recipes/verifiers_rl/verifiers_env.py
@@ -1,10 +1,80 @@
 from __future__ import annotations
 
+from contextvars import ContextVar
 from typing import Sequence
 
 import chz
+import tinker
 import verifiers as vf
-from tinker_cookbook.rl.types import EnvGroupBuilder, RLDataset, RLDatasetBuilder
+
+from tinker_cookbook.completers import TokensWithLogprobs
+from tinker_cookbook.rl.types import (
+    EnvGroupBuilder,
+    RLDataset,
+    RLDatasetBuilder,
+    Trajectory,
+    TrajectoryGroup,
+    Transition,
+)
+
+_vf_env_ctx: ContextVar[vf.Environment | None] = ContextVar("vf_env", default=None)
+
+
+def set_vf_env(env: vf.Environment) -> None:
+    """Set the verifiers environment for the current context."""
+    _vf_env_ctx.set(env)
+
+
+def get_vf_env() -> vf.Environment | None:
+    """Get the verifiers environment from the current context."""
+    return _vf_env_ctx.get()
+
+
+def convert_states_to_trajectory_group(states: list[vf.State]) -> TrajectoryGroup:
+    """Convert verifiers States to tinker TrajectoryGroup."""
+    trajectories_G: list[Trajectory] = []
+    final_rewards_G: list[float] = []
+    metrics_G: list[dict[str, float | int]] = []
+
+    for state in states:
+        transitions: list[Transition] = []
+        trajectory_steps = state.get("trajectory", [])
+
+        for i, step in enumerate(trajectory_steps):
+            tokens_data = step.get("tokens")
+            if tokens_data is not None:
+                prompt_ids = tokens_data.get("prompt_ids", [])
+                ob = tinker.ModelInput.from_ints(prompt_ids)
+                completion_ids = tokens_data.get("completion_ids", [])
+                completion_logprobs = tokens_data.get("completion_logprobs", [])
+                ac = TokensWithLogprobs(
+                    tokens=completion_ids,
+                    maybe_logprobs=completion_logprobs,
+                )
+            else:
+                ob = tinker.ModelInput.empty()
+                ac = TokensWithLogprobs(tokens=[], maybe_logprobs=[])
+
+            is_last = i == len(trajectory_steps) - 1
+            transition = Transition(
+                ob=ob,
+                ac=ac,
+                reward=0.0,
+                episode_done=is_last,
+                metrics={},
+            )
+            transitions.append(transition)
+
+        trajectory = Trajectory(transitions=transitions, final_ob=tinker.ModelInput.empty())
+        trajectories_G.append(trajectory)
+        final_rewards_G.append(state.get("reward") or 0.0)
+        metrics_G.append(state.get("metrics") or {})
+
+    return TrajectoryGroup(
+        trajectories_G=trajectories_G,
+        final_rewards_G=final_rewards_G,
+        metrics_G=metrics_G,
+    )
 
 
 class VerifiersRLDataset(RLDataset):
@@ -31,8 +101,9 @@ def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
                 VerifiersEnvGroupBuilder(
                     vf_env=self.vf_env,
                     prompt=row["prompt"],
+                    example_id=row["example_id"],
+                    task=row["task"],
                     answer=row.get("answer", ""),
-                    task=row.get("task", "default"),
                     info=row.get("info", {}),
                 )
             )
@@ -41,23 +112,29 @@ def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
 
 @chz.chz
 class VerifiersRLDatasetBuilder(RLDatasetBuilder):
-    vf_env: vf.Environment
-    groups_per_batch: int
-    dataset_n: int
-    dataset_seed: int | None
+    vf_env_id: str
+    vf_env_args: dict = chz.field(default_factory=dict)
+    groups_per_batch: int = 32
+    dataset_n: int = -1
+    dataset_seed: int | None = None
 
     async def __call__(self) -> tuple[RLDataset, RLDataset | None]:
-        ds = self.vf_env.get_dataset(n=self.dataset_n, seed=self.dataset_seed)
+        vf_env = get_vf_env()
+        if vf_env is None:
+            vf_env = vf.load_environment(self.vf_env_id, **self.vf_env_args)
+            set_vf_env(vf_env)
+        ds = vf_env.get_dataset(n=self.dataset_n, seed=self.dataset_seed)
         rows = [
             {
                 "prompt": ds["prompt"][i],
+                "example_id": ds["example_id"][i],
+                "task": ds["task"][i],
                 **({"answer": ds["answer"][i]} if "answer" in ds.column_names else {}),
-                **({"task": ds["task"][i]} if "task" in ds.column_names else {}),
                 **({"info": ds["info"][i]} if "info" in ds.column_names else {}),
             }
             for i in range(len(ds))
         ]
-        return VerifiersRLDataset(rows, self.vf_env, self.groups_per_batch), None
+        return VerifiersRLDataset(rows, vf_env, self.groups_per_batch), None
 
 
 class VerifiersEnvGroupBuilder(EnvGroupBuilder):
@@ -65,18 +142,32 @@ def __init__(
         self,
         vf_env: vf.Environment,
         prompt: vf.Messages,
-        answer: str,
+        example_id: int,
         task: str,
-        info: dict,
+        answer: str = "",
+        info: dict | None = None,
     ):
         self.vf_env = vf_env
         self.prompt = prompt
-        self.answer = answer
+        self.example_id = example_id
         self.task = task
-        self.info = info
+        self.answer = answer
+        self.info = info or {}
+
+    def get_rollout_inputs(self, group_size: int) -> list[vf.RolloutInput]:
+        return [
+            vf.RolloutInput(
+                prompt=self.prompt,
+                answer=self.answer,
+                task=self.task,
+                info=self.info,
+                example_id=self.example_id,
+            )
+            for _ in range(group_size)
+        ]
 
     async def make_envs(self):
         return []  # unused when using custom_do_group_rollout
 
-    def logging_tags(self):
+    def logging_tags(self) -> list[str]:
         return [self.task] if self.task else []