pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -53,7 +53,7 @@ trackio = [
"trackio<1.0.0",
]
verifiers = [
"verifiers",
"verifiers>=0.1.8.post0",
"openai",
]
all = [
tinker_cookbook/recipes/verifiers_rl/tinker_openai.py (113 changes: 44 additions & 69 deletions)
@@ -11,49 +11,21 @@
from __future__ import annotations

import time
from typing import Any, Callable, Dict, List, Optional, overload, Literal
from typing import Any, Dict, List, Literal, overload

import tinker
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from openai import AsyncOpenAI
from openai._streaming import AsyncStream
from openai.resources.chat import AsyncChat as OpenAIAsyncChat
from openai.resources.chat.completions import AsyncCompletions as OpenAIAsyncChatCompletions
from openai.resources.completions import AsyncCompletions as OpenAIAsyncCompletions
from openai._streaming import AsyncStream
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion

from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import Tokenizer


GenerationHook = Callable[
[List[renderers.Message], tinker.ModelInput, List[int], List[float]], None
]


def convert_oai_messages_to_renderer_messages(
messages: List[Dict[str, Any]],
) -> List[renderers.Message]:
out: List[renderers.Message] = []
for m in messages:
role = str(m.get("role", "user"))
content = m.get("content", "")
# extract text from list of content parts if necessary
if isinstance(content, list):
text_parts: List[str] = []
for part in content:
if isinstance(part, dict):
if "text" in part:
text_parts.append(str(part["text"]))
elif isinstance(part, str):
text_parts.append(part)
content = "".join(text_parts)
else:
content = str(content)
out.append(renderers.Message(role=role, content=content))
return out


class TinkerAsyncOpenAIClient(AsyncOpenAI):
"""
OpenAI-compatible async client that routes calls to a Tinker SamplingClient.
@@ -69,10 +41,6 @@ def __init__(
self.sampling_client = sampling_client
self.renderer = renderer
self.tokenizer = tokenizer
self.hook: Optional[GenerationHook] = None

def set_generation_hook(self, hook: Optional[GenerationHook]) -> None:
self.hook = hook

def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None:
self.sampling_client = sampling_client
@@ -106,16 +74,18 @@ async def create(self, *args: Any, stream: bool, **kwargs: Any) -> ChatCompletio
async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStream[Any]:
model = kwargs.get("model", "tinker")
messages = kwargs.get("messages", [])
if kwargs.get("tools"):
raise NotImplementedError("Tool calling is not yet supported by this model's renderer.")
if kwargs.get("stream", False):
raise ValueError("stream=True not supported by TinkerAsyncOpenAIClient")
sampling_args = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools")}

# prepare prompt
conv_messages = convert_oai_messages_to_renderer_messages(messages)
stop = sampling_args.get("stop", self._parent.renderer.get_stop_sequences())
max_tokens = sampling_args.get("max_tokens") or sampling_args.get("max_completion_tokens")

model_input = self._parent.renderer.build_generation_prompt(conv_messages)
model_input = self._parent.renderer.build_generation_prompt(messages)
prompt_token_ids: List[int] = model_input.to_ints()

sample = await self._parent.sampling_client.sample_async(
prompt=model_input,
num_samples=1,
@@ -128,15 +98,12 @@ async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStrea
),
)
seq = sample.sequences[0]
tokens: List[int] = seq.tokens
logprobs: List[float] = seq.logprobs or [0.0] * len(tokens)

if self._parent.hook is not None:
self._parent.hook(conv_messages, model_input, tokens, logprobs)
completion_token_ids: List[int] = seq.tokens
logprobs: List[float] = seq.logprobs or [0.0] * len(completion_token_ids)

# build ChatCompletion via pydantic validation using renderer parsing
assistant_message, parse_success = self._parent.renderer.parse_response(tokens)
content_text = assistant_message["content"]
assistant_message, parse_success = self._parent.renderer.parse_response(
completion_token_ids
)
finish_reason = "stop" if parse_success else "length"
response_dict: Dict[str, Any] = {
"id": "tinker-chatcmpl",
@@ -146,23 +113,28 @@ async def create(self, *args: Any, **kwargs: Any) -> ChatCompletio
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content_text},
"message": assistant_message,
"finish_reason": finish_reason,
"logprobs": {
"content": [
{"token": f"token_id:{tid}", "logprob": float(lp), "top_logprobs": []}
for tid, lp in zip(tokens, logprobs)
{"token": f"token_id:{tid}", "logprob": lp, "top_logprobs": []}
for tid, lp in zip(completion_token_ids, logprobs)
]
},
}
],
"usage": {
"prompt_tokens": model_input.length,
"completion_tokens": len(tokens),
"total_tokens": model_input.length + len(tokens),
"prompt_tokens": len(prompt_token_ids),
"completion_tokens": len(completion_token_ids),
"total_tokens": len(prompt_token_ids) + len(completion_token_ids),
},
}
return ChatCompletion.model_validate(response_dict)
response = ChatCompletion.model_validate(response_dict)

setattr(response, "prompt_token_ids", prompt_token_ids)
setattr(response.choices[0], "token_ids", completion_token_ids)

return response


class TinkerCompletions(OpenAIAsyncCompletions):
@@ -190,10 +162,9 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
prompt = kwargs.get("prompt", "")
sampling_args = {k: v for k, v in kwargs.items() if k not in ("model", "prompt")}

# Completion-mode: render prompt directly as text chunk
model_input = tinker.ModelInput.from_ints(
self._parent.tokenizer.encode(prompt, add_special_tokens=True)
)
prompt_token_ids: List[int] = self._parent.tokenizer.encode(prompt, add_special_tokens=True)
model_input = tinker.ModelInput.from_ints(prompt_token_ids)

sample = await self._parent.sampling_client.sample_async(
prompt=model_input,
num_samples=1,
@@ -205,11 +176,11 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
),
)
seq = sample.sequences[0]
tokens: List[int] = seq.tokens
logprobs: List[float] = seq.logprobs or [0.0] * len(tokens)
completion_token_ids: List[int] = seq.tokens
logprobs: List[float] = seq.logprobs or [0.0] * len(completion_token_ids)

text = self._parent.tokenizer.decode(tokens)
tokens_str = [f"token_id:{tid}" for tid in tokens]
text = self._parent.tokenizer.decode(completion_token_ids)
tokens_str = [f"token_id:{tid}" for tid in completion_token_ids]
response_dict: Dict[str, Any] = {
"id": "tinker-cmpl",
"object": "text_completion",
@@ -222,20 +193,24 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
"finish_reason": "stop",
"logprobs": {
"tokens": tokens_str,
"token_logprobs": [float(lp) for lp in logprobs],
"token_logprobs": logprobs,
},
}
],
"usage": {
"prompt_tokens": model_input.length,
"completion_tokens": len(tokens),
"total_tokens": model_input.length + len(tokens),
"prompt_tokens": len(prompt_token_ids),
"completion_tokens": len(completion_token_ids),
"total_tokens": len(prompt_token_ids) + len(completion_token_ids),
},
}
final = Completion.model_validate(response_dict)
response = Completion.model_validate(response_dict)

setattr(response.choices[0], "prompt_token_ids", prompt_token_ids)
setattr(response.choices[0], "token_ids", completion_token_ids)

if stream:
return TinkerAsyncCompletionStream(final)
return final
return TinkerAsyncCompletionStream(response)
return response


class TinkerAsyncChat(OpenAIAsyncChat):
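
A minimal usage sketch of the client after this change (not part of the diff). Assumptions, since they are not visible in these hunks: the constructor keyword names `sampling_client`, `renderer`, and `tokenizer` (only the attribute assignments appear above), that the client exposes the standard `chat.completions` and `completions` resources per its `AsyncOpenAI` base class and docstring, and that the sampling client, renderer, and tokenizer are built elsewhere via the usual tinker / tinker_cookbook setup. The `prompt_token_ids` / `token_ids` reads at the end correspond to the attributes attached via `setattr` in the hunks above.

# Sketch only; anything marked "assumed" is not confirmed by this diff.
from tinker_cookbook.recipes.verifiers_rl.tinker_openai import TinkerAsyncOpenAIClient


async def demo(sampling_client, renderer, tokenizer) -> None:
    # Assumed keyword names; the diff only shows the attribute assignments.
    client = TinkerAsyncOpenAIClient(
        sampling_client=sampling_client,
        renderer=renderer,
        tokenizer=tokenizer,
    )

    # Chat path: raw token ids ride alongside the standard OpenAI fields.
    chat = await client.chat.completions.create(
        model="tinker",
        messages=[{"role": "user", "content": "Say hello."}],
        max_tokens=64,
    )
    print(chat.choices[0].message.content)
    print(chat.prompt_token_ids)        # attached to the ChatCompletion via setattr
    print(chat.choices[0].token_ids)    # attached to the first choice via setattr

    # Raw-completion path: both id lists are attached to the choice.
    comp = await client.completions.create(
        model="tinker",
        prompt="The capital of France is",
        max_tokens=16,
    )
    print(comp.choices[0].text)
    print(comp.choices[0].prompt_token_ids)
    print(comp.choices[0].token_ids)

# Run with e.g. asyncio.run(demo(sampling_client, renderer, tokenizer)) once the
# tinker sampling client, renderer, and tokenizer have been constructed.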