From 737c78da30d938918cfd2026f0b064b5666f3dc6 Mon Sep 17 00:00:00 2001
From: Sean McGuire
Date: Tue, 11 Mar 2025 14:54:48 -0700
Subject: [PATCH] still no tokens :(

---
 evals/BraintrustClient.ts | 368 ++++++++++++++++++++++++++++++++++++++
 evals/initStagehand.ts    |   9 +-
 evals/taskConfig.ts       |  12 +-
 lib/llm/LLMClient.ts      |   8 +-
 lib/llm/LLMProvider.ts    |  11 ++
 types/model.ts            |  18 ++-
 6 files changed, 416 insertions(+), 10 deletions(-)
 create mode 100644 evals/BraintrustClient.ts

diff --git a/evals/BraintrustClient.ts b/evals/BraintrustClient.ts
new file mode 100644
index 000000000..0d57bb105
--- /dev/null
+++ b/evals/BraintrustClient.ts
@@ -0,0 +1,368 @@
+import OpenAI from "openai";
+import type { ClientOptions } from "openai";
+import { zodToJsonSchema } from "zod-to-json-schema";
+import { LogLine } from "../types/log";
+import { AvailableModel } from "../types/model";
+import { LLMCache } from "../lib/cache/LLMCache";
+import {
+  ChatMessage,
+  CreateChatCompletionOptions,
+  LLMClient,
+  LLMResponse,
+} from "../lib/llm/LLMClient";
+import { wrapOpenAI } from "braintrust";
+
+/** Maximum number of times a structured-output request is re-issued. */
+const MAX_JSON_RETRIES = 5;
+
+/** Returns a printable message for a caught value of unknown type. */
+function errorMessage(error: unknown): string {
+  return error instanceof Error ? error.message : String(error);
+}
+
+/**
+ * LLM client that sends OpenAI-format chat completions to the Braintrust
+ * proxy endpoint, with optional response caching.
+ */
+export class BraintrustClient extends LLMClient {
+  public type = "braintrust" as const;
+  private client: OpenAI;
+  private cache: LLMCache | undefined;
+  private enableCaching: boolean;
+  public clientOptions: ClientOptions;
+  public hasVision = false;
+
+  constructor({
+    enableCaching = false,
+    cache,
+    modelName,
+    clientOptions,
+    userProvidedInstructions,
+  }: {
+    logger: (message: LogLine) => void;
+    enableCaching?: boolean;
+    cache?: LLMCache;
+    modelName: AvailableModel;
+    clientOptions?: ClientOptions;
+    userProvidedInstructions?: string;
+  }) {
+    super(modelName, userProvidedInstructions);
+
+    // Caller options may override the proxy base URL, but the API key is
+    // resolved last so an explicit `apiKey: undefined` in clientOptions
+    // cannot erase the BRAINTRUST_API_KEY fallback.
+    this.client = wrapOpenAI(
+      new OpenAI({
+        baseURL: "https://api.braintrust.dev/v1/proxy",
+        ...clientOptions,
+        apiKey: clientOptions?.apiKey || process.env.BRAINTRUST_API_KEY,
+      }),
+    );
+
+    this.cache = cache;
+    this.enableCaching = enableCaching;
+    this.modelName = modelName;
+    this.clientOptions = clientOptions;
+  }
+
+  /**
+   * Sends a chat completion request through the Braintrust proxy.
+   *
+   * With `options.response_model` set, resolves to the JSON parsed from the
+   * model's tool call (or, failing that, from its text content), re-issuing
+   * the request up to MAX_JSON_RETRIES times. Otherwise resolves to the
+   * response in LLMResponse shape.
+   *
+   * @throws Error if no valid JSON can be extracted once retries are
+   * exhausted; errors from the API call are logged and rethrown.
+   */
+  async createChatCompletion<T = LLMResponse>({
+    options,
+    retries,
+    logger,
+  }: CreateChatCompletionOptions): Promise<T> {
+    const optionsWithoutImage = { ...options };
+    delete optionsWithoutImage.image;
+
+    logger({
+      category: "braintrust",
+      message: "creating chat completion",
+      level: 1,
+      auxiliary: {
+        options: {
+          value: JSON.stringify(optionsWithoutImage),
+          type: "object",
+        },
+      },
+    });
+
+    // Try to get cached response
+    const cacheOptions = {
+      model: this.modelName,
+      messages: options.messages,
+      temperature: options.temperature,
+      response_model: options.response_model,
+      tools: options.tools,
+      retries: retries,
+    };
+
+    // Caching needs both the flag and a cache instance; the constructor
+    // accepts `enableCaching: true` without a cache.
+    const cache = this.enableCaching ? this.cache : undefined;
+
+    if (cache) {
+      const cachedResponse = await cache.get(
+        cacheOptions,
+        options.requestId,
+      );
+      if (cachedResponse) {
+        logger({
+          category: "llm_cache",
+          message: "LLM cache hit - returning cached response",
+          level: 1,
+          auxiliary: {
+            cachedResponse: {
+              value: JSON.stringify(cachedResponse),
+              type: "object",
+            },
+            requestId: {
+              value: options.requestId,
+              type: "string",
+            },
+            cacheOptions: {
+              value: JSON.stringify(cacheOptions),
+              type: "object",
+            },
+          },
+        });
+        return cachedResponse as T;
+      }
+    }
+
+    // Format messages for Braintrust API (using OpenAI format)
+    const formattedMessages = options.messages.map((msg: ChatMessage) => {
+      const baseMessage = {
+        content:
+          typeof msg.content === "string"
+            ? msg.content
+            : Array.isArray(msg.content) &&
+                msg.content.length > 0 &&
+                "text" in msg.content[0]
+              ? msg.content[0].text
+              : "",
+      };
+
+      // Braintrust only supports system, user, and assistant roles
+      if (msg.role === "system") {
+        return { ...baseMessage, role: "system" as const };
+      } else if (msg.role === "assistant") {
+        return { ...baseMessage, role: "assistant" as const };
+      } else {
+        // Default to user for any other role
+        return { ...baseMessage, role: "user" as const };
+      }
+    });
+
+    // Format tools if provided
+    let tools = options.tools?.map((tool) => ({
+      type: "function" as const,
+      function: {
+        name: tool.name,
+        description: tool.description,
+        parameters: {
+          type: "object",
+          properties: tool.parameters.properties,
+          required: tool.parameters.required,
+        },
+      },
+    }));
+
+    // Convert the Zod schema once: the JSON Schema feeds both the extraction
+    // tool and the JSON instruction appended to the messages.
+    const jsonSchema = options.response_model
+      ? (zodToJsonSchema(options.response_model.schema) as {
+          properties?: Record<string, unknown>;
+          required?: string[];
+        })
+      : undefined;
+
+    // Add response model as a tool if provided
+    if (jsonSchema) {
+      const responseTool = {
+        type: "function" as const,
+        function: {
+          name: "print_extracted_data",
+          description:
+            "Prints the extracted data based on the provided schema.",
+          parameters: {
+            type: "object",
+            properties: jsonSchema.properties || {},
+            required: jsonSchema.required || [],
+          },
+        },
+      };
+
+      tools = tools ? [...tools, responseTool] : [responseTool];
+    }
+
+    try {
+      // Use OpenAI client with Braintrust API
+      const apiResponse = await this.client.chat.completions.create({
+        model: this.modelName,
+        messages: [
+          ...formattedMessages,
+          // Add explicit instruction to return JSON if we have a response model.
+          // The converted JSON Schema is serialized here; stringifying the Zod
+          // object itself would emit Zod internals, not the schema.
+          ...(jsonSchema
+            ? [
+                {
+                  role: "system" as const,
+                  content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(jsonSchema)}`,
+                },
+              ]
+            : []),
+        ],
+        // `??` keeps an explicit temperature of 0 instead of replacing it.
+        temperature: options.temperature ?? 0.7,
+        max_tokens: options.maxTokens,
+        // `tool_choice` is only valid alongside a non-empty `tools` array,
+        // so send both together or neither.
+        ...(tools && tools.length > 0
+          ? { tools, tool_choice: options.tool_choice || "auto" }
+          : {}),
+      });
+
+      // Format the response to match the expected LLMResponse format
+      const response: LLMResponse = {
+        id: apiResponse.id,
+        object: "chat.completion",
+        created: Date.now(),
+        model: this.modelName,
+        choices: [
+          {
+            index: 0,
+            message: {
+              role: "assistant",
+              content: apiResponse.choices[0]?.message?.content || null,
+              tool_calls: apiResponse.choices[0]?.message?.tool_calls || [],
+            },
+            finish_reason: apiResponse.choices[0]?.finish_reason || "stop",
+          },
+        ],
+        usage: {
+          prompt_tokens: apiResponse.usage?.prompt_tokens || 0,
+          completion_tokens: apiResponse.usage?.completion_tokens || 0,
+          total_tokens: apiResponse.usage?.total_tokens || 0,
+        },
+      };
+
+      logger({
+        category: "braintrust",
+        message: "response",
+        level: 1,
+        auxiliary: {
+          response: {
+            value: JSON.stringify(response),
+            type: "object",
+          },
+          requestId: {
+            value: options.requestId,
+            type: "string",
+          },
+        },
+      });
+
+      if (options.response_model) {
+        // First try standard function calling format
+        const toolCall = response.choices[0]?.message?.tool_calls?.[0];
+        if (toolCall?.function?.arguments) {
+          try {
+            const result = JSON.parse(toolCall.function.arguments);
+            if (cache) {
+              cache.set(cacheOptions, result, options.requestId);
+            }
+            return result as T;
+          } catch (e) {
+            // If JSON parse fails, the model might be returning a different format
+            logger({
+              category: "braintrust",
+              message: "failed to parse tool call arguments as JSON, retrying",
+              level: 1,
+              auxiliary: {
+                error: {
+                  value: errorMessage(e),
+                  type: "string",
+                },
+              },
+            });
+          }
+        }
+
+        // If we have content but no tool calls, try to parse the content as JSON
+        const content = response.choices[0]?.message?.content;
+        if (content) {
+          try {
+            // Try to extract JSON from the content
+            const jsonMatch = content.match(/\{[\s\S]*\}/);
+            if (jsonMatch) {
+              const result = JSON.parse(jsonMatch[0]);
+              if (cache) {
+                cache.set(cacheOptions, result, options.requestId);
+              }
+              return result as T;
+            }
+          } catch (e) {
+            logger({
+              category: "braintrust",
+              message: "failed to parse content as JSON",
+              level: 1,
+              auxiliary: {
+                error: {
+                  value: errorMessage(e),
+                  type: "string",
+                },
+              },
+            });
+          }
+        }
+
+        // If we still haven't found valid JSON and have retries left, try again
+        if ((retries ?? 0) < MAX_JSON_RETRIES) {
+          return this.createChatCompletion({
+            options,
+            logger,
+            retries: (retries ?? 0) + 1,
+          });
+        }
+
+        throw new Error(
+          "Create Chat Completion Failed: Could not extract valid JSON from response",
+        );
+      }
+
+      if (cache) {
+        cache.set(cacheOptions, response, options.requestId);
+      }
+
+      return response as T;
+    } catch (error) {
+      logger({
+        category: "braintrust",
+        message: "error creating chat completion",
+        level: 1,
+        auxiliary: {
+          error: {
+            value: errorMessage(error),
+            type: "string",
+          },
+          requestId: {
+            value: options.requestId,
+            type: "string",
+          },
+        },
+      });
+      throw error;
+    }
+  }
+}
diff --git a/evals/initStagehand.ts b/evals/initStagehand.ts
index bb26c9b30..fa6190ee1 100644
--- a/evals/initStagehand.ts
+++ b/evals/initStagehand.ts
@@ -31,9 +31,9 @@ const StagehandConfig = {
   headless: false,
   enableCaching,
   domSettleTimeoutMs: 30_000,
-  modelName: "gpt-4o", // default model, can be overridden by initStagehand arguments
+  modelName: "braintrust-gpt-4o", // default model, can be overridden by initStagehand arguments
   modelClientOptions: {
-    apiKey: process.env.OPENAI_API_KEY,
+    apiKey: process.env.BRAINTRUST_API_KEY,
   },
   logger: (logLine: LogLine) =>
     console.log(`[stagehand::${logLine.category}] ${logLine.message}`),
@@ -63,10 +63,7 @@ export const initStagehand = async ({
   configOverrides?: Partial<ConstructorParams>;
   actTimeoutMs?: number;
 }) => {
-  let chosenApiKey: string | undefined = process.env.OPENAI_API_KEY;
-  if (modelName.startsWith("claude")) {
-    chosenApiKey = process.env.ANTHROPIC_API_KEY;
-  }
+  const chosenApiKey = process.env.BRAINTRUST_API_KEY;
 
   const config = {
     ...StagehandConfig,
diff --git a/evals/taskConfig.ts b/evals/taskConfig.ts
index 0031b8a80..2e194ff54 100644
--- a/evals/taskConfig.ts
+++ b/evals/taskConfig.ts
@@ -49,7 +49,17 @@ if (filterByEvalName && !tasksByName[filterByEvalName]) {
  */
 const DEFAULT_EVAL_MODELS = process.env.EVAL_MODELS
   ? process.env.EVAL_MODELS.split(",")
-  : ["gpt-4o", "claude-3-5-sonnet-latest"];
+  : [
+      "braintrust-gpt-4o",
+      "braintrust-gpt-4.5-preview",
+      "braintrust-gpt-4o-mini",
+      "braintrust-claude-3-5-sonnet-latest",
+      "braintrust-claude-3-7-sonnet-latest",
+      "braintrust-gemini-2.0-flash",
+      "braintrust-llama-3.3-70b-versatile",
+      "braintrust-llama-3.1-8b-instant",
+      "braintrust-deepseek-r1-distill-llama-70b",
+    ];
 
 /**
  * getModelList:
diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts
index 1f060a510..f3fffd59f 100644
--- a/lib/llm/LLMClient.ts
+++ b/lib/llm/LLMClient.ts
@@ -81,7 +81,13 @@ export interface CreateChatCompletionOptions {
 }
 
 export abstract class LLMClient {
-  public type: "openai" | "anthropic" | "cerebras" | "groq" | string;
+  public type:
+    | "openai"
+    | "anthropic"
+    | "cerebras"
+    | "groq"
+    | "braintrust"
+    | string;
   public modelName: AvailableModel;
   public hasVision: boolean;
   public clientOptions: ClientOptions;
diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts
index 3ee3f29d8..d2e8cabcb 100644
--- a/lib/llm/LLMProvider.ts
+++ b/lib/llm/LLMProvider.ts
@@ -10,6 +10,7 @@ import { CerebrasClient } from "./CerebrasClient";
 import { GroqClient } from "./GroqClient";
 import { LLMClient } from "./LLMClient";
 import { OpenAIClient } from "./OpenAIClient";
+import { BraintrustClient } from "../../evals/BraintrustClient";
 
 const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = {
   "gpt-4o": "openai",
@@ -64,6 +65,16 @@ export class LLMProvider {
     modelName: AvailableModel,
     clientOptions?: ClientOptions,
   ): LLMClient {
+    if (modelName.startsWith("braintrust-")) {
+      return new BraintrustClient({
+        logger: this.logger,
+        enableCaching: this.enableCaching,
+        cache: this.cache,
+        // Strip only the leading prefix; the remainder is passed on as the model id.
+        modelName: modelName.slice("braintrust-".length) as AvailableModel,
+        clientOptions,
+      });
+    }
     const provider = modelToProviderMap[modelName];
     if (!provider) {
       throw new Error(`Unsupported model: ${modelName}`);
diff --git a/types/model.ts b/types/model.ts
index 666cb6818..2547da0d8 100644
--- a/types/model.ts
+++ b/types/model.ts
@@ -2,28 +2,42 @@ import type { ClientOptions as AnthropicClientOptions } from "@anthropic-ai/sdk"
 import type { ClientOptions as OpenAIClientOptions } from "openai";
 import { z } from "zod";
 
-export const AvailableModelSchema = z.enum([
+// Create a base schema for specific known models
+const BaseModelSchema = z.enum([
   "gpt-4o",
   "gpt-4o-mini",
   "gpt-4o-2024-08-06",
   "gpt-4.5-preview",
   "claude-3-5-sonnet-latest",
   "claude-3-5-sonnet-20241022",
   "claude-3-5-sonnet-20240620",
   "claude-3-7-sonnet-latest",
   "claude-3-7-sonnet-20250219",
   "o1-mini",
   "o1-preview",
   "o3-mini",
   "cerebras-llama-3.3-70b",
   "cerebras-llama-3.1-8b",
   "groq-llama-3.3-70b-versatile",
   "groq-llama-3.3-70b-specdec",
+]);
+
+// Create a schema that also accepts any string starting with "braintrust-"
+export const AvailableModelSchema = z.union([
+  BaseModelSchema,
+  z.string().refine((val) => val.startsWith("braintrust-"), {
+    message: "Braintrust models must start with 'braintrust-'",
+  }),
 ]);
 
 export type AvailableModel = z.infer<typeof AvailableModelSchema>;
 
-export type ModelProvider = "openai" | "anthropic" | "cerebras" | "groq";
+export type ModelProvider =
+  | "openai"
+  | "anthropic"
+  | "cerebras"
+  | "braintrust"
+  | "groq";
 
 export type ClientOptions = OpenAIClientOptions | AnthropicClientOptions;
 