From 80fbb5d1d24a9b2b7bd66574f8679f9646f8a822 Mon Sep 17 00:00:00 2001 From: Anirudh Kamath Date: Fri, 28 Feb 2025 07:33:22 -0800 Subject: [PATCH 1/2] add gemini support via openai --- examples/example.ts | 24 ++- lib/llm/GoogleClient.ts | 328 ++++++++++++++++++++++++++++++++++++++++ lib/llm/LLMClient.ts | 2 +- lib/llm/LLMProvider.ts | 11 ++ types/model.ts | 4 +- 5 files changed, 365 insertions(+), 4 deletions(-) create mode 100644 lib/llm/GoogleClient.ts diff --git a/examples/example.ts b/examples/example.ts index 41d8d861f..b5411f511 100644 --- a/examples/example.ts +++ b/examples/example.ts @@ -5,15 +5,35 @@ * npx create-browser-app@latest my-browser-app */ -import { Stagehand } from "@/dist"; +import { AvailableModel, Stagehand } from "@/dist"; import StagehandConfig from "@/stagehand.config"; - +import { z } from "zod"; async function example() { + const modelName = "cerebras-llama-3.3-70b"; + // const modelName = "gemini-2.0-flash"; const stagehand = new Stagehand({ ...StagehandConfig, + env: "LOCAL", + modelName, + modelClientOptions: { + apiKey: + modelName === ("gemini-2.0-flash" as AvailableModel) + ? process.env.GOOGLE_API_KEY + : process.env.CEREBRAS_API_KEY, + }, }); await stagehand.init(); await stagehand.page.goto("https://docs.stagehand.dev"); + await stagehand.page.act("Click the quickstart"); + const { text } = await stagehand.page.extract({ + instruction: "Extract the title", + schema: z.object({ + text: z.string(), + }), + useTextExtract: true, + }); + console.log(text); + await stagehand.close(); } (async () => { diff --git a/lib/llm/GoogleClient.ts b/lib/llm/GoogleClient.ts new file mode 100644 index 000000000..52337718c --- /dev/null +++ b/lib/llm/GoogleClient.ts @@ -0,0 +1,328 @@ +import OpenAI from "openai"; +import type { ClientOptions } from "openai"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { LogLine } from "../../types/log"; +import { AvailableModel } from "../../types/model"; +import { LLMCache } from "../cache/LLMCache"; +import { + ChatMessage, + CreateChatCompletionOptions, + LLMClient, + LLMResponse, +} from "./LLMClient"; + +export class GoogleClient extends LLMClient { + public type = "google" as const; + private client: OpenAI; + private cache: LLMCache | undefined; + private enableCaching: boolean; + public clientOptions: ClientOptions; + public hasVision = false; + + constructor({ + enableCaching = false, + cache, + modelName, + clientOptions, + userProvidedInstructions, + }: { + logger: (message: LogLine) => void; + enableCaching?: boolean; + cache?: LLMCache; + modelName: AvailableModel; + clientOptions?: ClientOptions; + userProvidedInstructions?: string; + }) { + super(modelName, userProvidedInstructions); + + // Create OpenAI client with the base URL set to Google API + this.client = new OpenAI({ + baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/", + apiKey: clientOptions?.apiKey || process.env.GEMINI_API_KEY, + ...clientOptions, + }); + + this.cache = cache; + this.enableCaching = enableCaching; + this.modelName = modelName; + this.clientOptions = clientOptions; + } + + async createChatCompletion({ + options, + retries, + logger, + }: CreateChatCompletionOptions): Promise { + const optionsWithoutImage = { ...options }; + delete optionsWithoutImage.image; + + logger({ + category: "google", + message: "creating chat completion", + level: 1, + auxiliary: { + options: { + value: JSON.stringify(optionsWithoutImage), + type: "object", + }, + }, + }); + + // Try to get cached response + const cacheOptions = { + 
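        // Everything that can change the completion (model, messages, temperature,
        // response_model, tools, retries) goes into the cache key, so a cached
        // response is only reused for an identical request.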
model: this.modelName, + messages: options.messages, + temperature: options.temperature, + response_model: options.response_model, + tools: options.tools, + retries: retries, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get( + cacheOptions, + options.requestId, + ); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + }, + }); + return cachedResponse as T; + } + } + + // Format messages for Google API (using OpenAI format) + const formattedMessages = options.messages.map((msg: ChatMessage) => { + const baseMessage = { + content: + typeof msg.content === "string" + ? msg.content + : Array.isArray(msg.content) && + msg.content.length > 0 && + "text" in msg.content[0] + ? msg.content[0].text + : "", + }; + + // Google only supports system, user, and assistant roles + if (msg.role === "system") { + return { ...baseMessage, role: "system" as const }; + } else if (msg.role === "assistant") { + return { ...baseMessage, role: "assistant" as const }; + } else { + // Default to user for any other role + return { ...baseMessage, role: "user" as const }; + } + }); + + // Format tools if provided + let tools = options.tools?.map((tool) => ({ + type: "function" as const, + function: { + name: tool.name, + description: tool.description, + parameters: { + type: "object", + properties: tool.parameters.properties, + required: tool.parameters.required, + }, + }, + })); + + // Add response model as a tool if provided + if (options.response_model) { + const jsonSchema = zodToJsonSchema(options.response_model.schema) as { + properties?: Record; + required?: string[]; + }; + const schemaProperties = jsonSchema.properties || {}; + const schemaRequired = jsonSchema.required || []; + + const responseTool = { + type: "function" as const, + function: { + name: "print_extracted_data", + description: + "Prints the extracted data based on the provided schema.", + parameters: { + type: "object", + properties: schemaProperties, + required: schemaRequired, + }, + }, + }; + + tools = tools ? [...tools, responseTool] : [responseTool]; + } + + try { + // Use OpenAI client with Google API + const apiResponse = await this.client.chat.completions.create({ + model: this.modelName, + messages: [ + ...formattedMessages, + // Add explicit instruction to return JSON if we have a response model + ...(options.response_model + ? 
[ + { + role: "system" as const, + content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(options.response_model.schema)}`, + }, + ] + : []), + ], + temperature: options.temperature || 0.7, + max_tokens: options.maxTokens, + tools: tools, + tool_choice: options.tool_choice || "auto", + }); + + // Format the response to match the expected LLMResponse format + const response: LLMResponse = { + id: apiResponse.id, + object: "chat.completion", + created: Date.now(), + model: this.modelName, + choices: [ + { + index: 0, + message: { + role: "assistant", + content: apiResponse.choices[0]?.message?.content || null, + tool_calls: apiResponse.choices[0]?.message?.tool_calls || [], + }, + finish_reason: apiResponse.choices[0]?.finish_reason || "stop", + }, + ], + usage: { + prompt_tokens: apiResponse.usage?.prompt_tokens || 0, + completion_tokens: apiResponse.usage?.completion_tokens || 0, + total_tokens: apiResponse.usage?.total_tokens || 0, + }, + }; + + logger({ + category: "google", + message: "response", + level: 1, + auxiliary: { + response: { + value: JSON.stringify(response), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + + if (options.response_model) { + // First try standard function calling format + const toolCall = response.choices[0]?.message?.tool_calls?.[0]; + if (toolCall?.function?.arguments) { + try { + const result = JSON.parse(toolCall.function.arguments); + if (this.enableCaching) { + this.cache.set(cacheOptions, result, options.requestId); + } + return result as T; + } catch (e) { + // If JSON parse fails, the model might be returning a different format + logger({ + category: "google", + message: "failed to parse tool call arguments as JSON, retrying", + level: 1, + auxiliary: { + error: { + value: e.message, + type: "string", + }, + }, + }); + } + } + + // If we have content but no tool calls, try to parse the content as JSON + const content = response.choices[0]?.message?.content; + if (content) { + try { + // Try to extract JSON from the content + const jsonMatch = content.match(/\{[\s\S]*\}/); + if (jsonMatch) { + const result = JSON.parse(jsonMatch[0]); + if (this.enableCaching) { + this.cache.set(cacheOptions, result, options.requestId); + } + return result as T; + } + } catch (e) { + logger({ + category: "google", + message: "failed to parse content as JSON", + level: 1, + auxiliary: { + error: { + value: e.message, + type: "string", + }, + }, + }); + } + } + + // If we still haven't found valid JSON and have retries left, try again + if (!retries || retries < 5) { + return this.createChatCompletion({ + options, + logger, + retries: (retries ?? 
0) + 1,
          });
        }

        throw new Error(
          "Create Chat Completion Failed: Could not extract valid JSON from response",
        );
      }

      if (this.enableCaching) {
        this.cache.set(cacheOptions, response, options.requestId);
      }

      return response as T;
    } catch (error) {
      logger({
        category: "google",
        message: "error creating chat completion",
        level: 1,
        auxiliary: {
          error: {
            value: error.message,
            type: "string",
          },
          requestId: {
            value: options.requestId,
            type: "string",
          },
        },
      });
      throw error;
    }
  }
}
diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts
index a23e9ee94..d0760a141 100644
--- a/lib/llm/LLMClient.ts
+++ b/lib/llm/LLMClient.ts
@@ -81,7 +81,7 @@ export interface CreateChatCompletionOptions {
 }
 
 export abstract class LLMClient {
-  public type: "openai" | "anthropic" | "cerebras" | string;
+  public type: "openai" | "anthropic" | "cerebras" | "google" | string;
   public modelName: AvailableModel;
   public hasVision: boolean;
   public clientOptions: ClientOptions;
diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts
index 33d71bf54..796e710ff 100644
--- a/lib/llm/LLMProvider.ts
+++ b/lib/llm/LLMProvider.ts
@@ -7,6 +7,7 @@ import {
 import { LLMCache } from "../cache/LLMCache";
 import { AnthropicClient } from "./AnthropicClient";
 import { CerebrasClient } from "./CerebrasClient";
+import { GoogleClient } from "./GoogleClient";
 import { LLMClient } from "./LLMClient";
 import { OpenAIClient } from "./OpenAIClient";
@@ -23,6 +24,8 @@ const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = {
   "claude-3-7-sonnet-20250219": "anthropic",
   "cerebras-llama-3.3-70b": "cerebras",
   "cerebras-llama-3.1-8b": "cerebras",
+  "gemini-2.0-flash": "google",
+  "gemini-2.0-flash-lite": "google",
 };
 
 export class LLMProvider {
@@ -89,6 +92,14 @@ export class LLMProvider {
           modelName,
           clientOptions,
         });
+      case "google":
+        return new GoogleClient({
+          logger: this.logger,
+          enableCaching: this.enableCaching,
+          cache: this.cache,
+          modelName,
+          clientOptions,
+        });
       default:
         throw new Error(`Unsupported provider: ${provider}`);
     }
diff --git a/types/model.ts b/types/model.ts
index 8d1a6b49a..56ff91984 100644
--- a/types/model.ts
+++ b/types/model.ts
@@ -15,11 +15,13 @@ export const AvailableModelSchema = z.enum([
   "o3-mini",
   "cerebras-llama-3.3-70b",
   "cerebras-llama-3.1-8b",
+  "gemini-2.0-flash",
+  "gemini-2.0-flash-lite",
 ]);
 
 export type AvailableModel = z.infer<typeof AvailableModelSchema>;
 
-export type ModelProvider = "openai" | "anthropic" | "cerebras";
+export type ModelProvider = "openai" | "anthropic" | "cerebras" | "google";
 
 export type ClientOptions = OpenAIClientOptions | AnthropicClientOptions;

From 9edf4fe04e08948e433a03c736f9a05b0c78cdd4 Mon Sep 17 00:00:00 2001
From: Anirudh Kamath
Date: Fri, 28 Feb 2025 09:06:39 -0800
Subject: [PATCH 2/2] changeset

---
 .changeset/pink-fans-sparkle.md | 5 +++++
 1 file changed, 5 insertions(+)
 create mode 100644 .changeset/pink-fans-sparkle.md

diff --git a/.changeset/pink-fans-sparkle.md b/.changeset/pink-fans-sparkle.md
new file mode 100644
index 000000000..76de6ffa2
--- /dev/null
+++ b/.changeset/pink-fans-sparkle.md
@@ -0,0 +1,5 @@
+---
+"@browserbasehq/stagehand": minor
+---
+
+Add native support for Gemini via OpenAI API
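
Below is a minimal usage sketch of the new provider, mirroring the changes to examples/example.ts in this patch. It is not part of the diff itself: the `./stagehand.config` import path and the `main` wrapper are assumptions (the example in this patch uses the repo alias `@/stagehand.config`), and it assumes `GOOGLE_API_KEY` is set; GoogleClient falls back to `GEMINI_API_KEY` when no apiKey is passed. The model name, `modelClientOptions`, and extract schema come straight from the patch.

```ts
import { Stagehand } from "@browserbasehq/stagehand";
import StagehandConfig from "./stagehand.config";
import { z } from "zod";

async function main() {
  const stagehand = new Stagehand({
    ...StagehandConfig,
    env: "LOCAL",
    // Routed to the new GoogleClient, which talks to Google's
    // OpenAI-compatible endpoint at generativelanguage.googleapis.com.
    modelName: "gemini-2.0-flash",
    modelClientOptions: {
      apiKey: process.env.GOOGLE_API_KEY,
    },
  });

  await stagehand.init();
  await stagehand.page.goto("https://docs.stagehand.dev");

  // Structured extraction exercises the response_model -> "print_extracted_data"
  // tool-call path added in GoogleClient.createChatCompletion.
  const { title } = await stagehand.page.extract({
    instruction: "Extract the page title",
    schema: z.object({ title: z.string() }),
  });
  console.log(title);

  await stagehand.close();
}

main().catch(console.error);
```

`gemini-2.0-flash-lite` is registered the same way in modelToProviderMap, so it can be swapped in without any other changes.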