39 changes: 39 additions & 0 deletions packages/types/src/custom-models.ts
@@ -0,0 +1,39 @@
import { z } from "zod"

/**
* Schema for custom model information
* Defines the properties that can be specified for custom models
*/
export const customModelInfoSchema = z.object({
maxTokens: z.number().positive().optional(),
contextWindow: z.number().positive(),
supportsImages: z.boolean().optional(),
supportsPromptCache: z.boolean(), // Required in ModelInfo
supportsTemperature: z.boolean().optional(),
inputPrice: z.number().nonnegative().optional(),
outputPrice: z.number().nonnegative().optional(),
cacheWritesPrice: z.number().nonnegative().optional(),
cacheReadsPrice: z.number().nonnegative().optional(),
description: z.string().optional(),
supportsReasoningEffort: z.boolean().optional(),
supportsReasoningBudget: z.boolean().optional(),
requiredReasoningBudget: z.boolean().optional(),
reasoningEffort: z.string().optional(),
})

/**
* Schema for a custom models file
* The file is a simple record of model IDs to model information
* The provider is determined by the filename (e.g., openrouter.json)
*/
export const customModelsFileSchema = z.record(z.string(), customModelInfoSchema)

/**
* Type for the content of a custom models file
*/
export type CustomModelsFile = z.infer<typeof customModelsFileSchema>

/**
* Type for custom model information
*/
export type CustomModelInfo = z.infer<typeof customModelInfoSchema>
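For reference, a custom models file is plain JSON whose top-level keys are model IDs. The sketch below shows what a hypothetical openrouter.json might contain and how it would validate against customModelsFileSchema; the model IDs, values, and the relative import are illustrative assumptions, not part of this change.

// Hypothetical contents of a custom models file such as openrouter.json,
// validated with the schema defined above. Values are examples only.
import { customModelsFileSchema, type CustomModelsFile } from "./custom-models.js"

const raw = {
	"openai/gpt-4": {
		maxTokens: 16000,
		contextWindow: 128000,
		supportsImages: true,
		supportsPromptCache: false,
		description: "GPT-4 with a raised output token limit",
	},
	"custom/my-model": {
		contextWindow: 32000,
		supportsPromptCache: false,
	},
}

// safeParse returns { success, data | error } instead of throwing on invalid input.
const parsed = customModelsFileSchema.safeParse(raw)
if (parsed.success) {
	const models: CustomModelsFile = parsed.data
	console.log(Object.keys(models)) // ["openai/gpt-4", "custom/my-model"]
}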
1 change: 1 addition & 0 deletions packages/types/src/index.ts
@@ -2,6 +2,7 @@ export * from "./api.js"
export * from "./cloud.js"
export * from "./codebase-index.js"
export * from "./cookie-consent.js"
export * from "./custom-models.js"
export * from "./events.js"
export * from "./experiment.js"
export * from "./followup.js"
6 changes: 4 additions & 2 deletions src/api/providers/anthropic-vertex.ts
@@ -16,6 +16,7 @@ import { safeJsonParse } from "../../shared/safeJsonParse"
import { ApiStream } from "../transform/stream"
import { addCacheBreakpoints } from "../transform/caching/vertex"
import { getModelParams } from "../transform/model-params"
import { getProviderModelsSync } from "./model-lookup"

import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -164,8 +165,9 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple

getModel() {
const modelId = this.options.apiModelId
let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId
const info: ModelInfo = vertexModels[id]
const models = getProviderModelsSync("vertex", vertexModels as Record<string, ModelInfo>)
let id = modelId && modelId in models ? (modelId as VertexModelId) : vertexDefaultModelId
const info: ModelInfo = models[id]
const params = getModelParams({ format: "anthropic", modelId: id, model: info, settings: this.options })

// The `:thinking` suffix indicates that the model is a "Hybrid"
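The getProviderModelsSync helper imported above is not defined anywhere in this section; based only on how the provider handlers call it, a plausible sketch of its contract follows. The in-memory custom-model store, its population, and the import path are assumptions, not code from this PR.

// Hypothetical sketch of getProviderModelsSync, inferred from its call sites;
// the real implementation is not shown in this diff.
import type { ModelInfo } from "@roo-code/types" // assumed import path

// Custom model definitions loaded ahead of time (e.g. at activation) so the
// lookup can stay synchronous inside getModel().
const customModelsByProvider: Record<string, Record<string, ModelInfo>> = {}

export function getProviderModelsSync(
	provider: string,
	staticModels: Record<string, ModelInfo>,
): Record<string, ModelInfo> {
	// Spread static models first so custom definitions can override them,
	// matching the merge order used in modelCache.ts below.
	return { ...staticModels, ...(customModelsByProvider[provider] ?? {}) }
}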
6 changes: 4 additions & 2 deletions src/api/providers/anthropic.ts
@@ -14,6 +14,7 @@ import type { ApiHandlerOptions } from "../../shared/api"

import { ApiStream } from "../transform/stream"
import { getModelParams } from "../transform/model-params"
import { getProviderModelsSync } from "./model-lookup"

import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -249,8 +250,9 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa

getModel() {
const modelId = this.options.apiModelId
let id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId
let info: ModelInfo = anthropicModels[id]
const models = getProviderModelsSync("anthropic", anthropicModels as Record<string, ModelInfo>)
let id = modelId && modelId in models ? (modelId as AnthropicModelId) : anthropicDefaultModelId
let info: ModelInfo = models[id]

// If 1M context beta is enabled for Claude Sonnet 4 or 4.5, update the model info
if ((id === "claude-sonnet-4-20250514" || id === "claude-sonnet-4-5") && this.options.anthropicBeta1MContext) {
16 changes: 10 additions & 6 deletions src/api/providers/bedrock.ts
@@ -27,6 +27,7 @@ import {

import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { getProviderModelsSync } from "./model-lookup"
import { logger } from "../../utils/logging"
import { Package } from "../../shared/package"
import { MultiPointStrategy } from "../transform/cache-strategy/multi-point-strategy"
@@ -899,27 +900,30 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH

//Prompt Router responses come back in a different sequence and the model used is in the response and must be fetched by name
getModelById(modelId: string, modelType?: string): { id: BedrockModelId | string; info: ModelInfo } {
// Try to find the model in bedrockModels
// Get merged models (static + custom)
const models = getProviderModelsSync("bedrock", bedrockModels as Record<string, ModelInfo>)

// Try to find the model in merged models
const baseModelId = this.parseBaseModelId(modelId) as BedrockModelId

let model
if (baseModelId in bedrockModels) {
if (baseModelId in models) {
//Do a deep copy of the model info so that later in the code the model id and maxTokens can be set.
// The bedrockModels array is a constant and updating the model ID from the returned invokedModelID value
// The models array is a constant and updating the model ID from the returned invokedModelID value
// in a prompt router response isn't possible on the constant.
model = { id: baseModelId, info: JSON.parse(JSON.stringify(bedrockModels[baseModelId])) }
model = { id: baseModelId, info: JSON.parse(JSON.stringify(models[baseModelId])) }
} else if (modelType && modelType.includes("router")) {
model = {
id: bedrockDefaultPromptRouterModelId,
info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultPromptRouterModelId])),
info: JSON.parse(JSON.stringify(models[bedrockDefaultPromptRouterModelId])),
}
} else {
// Use heuristics for model info, then allow overrides from ProviderSettings
const guessed = this.guessModelInfoFromId(modelId)
model = {
id: bedrockDefaultModelId,
info: {
...JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
...JSON.parse(JSON.stringify(models[bedrockDefaultModelId])),
...guessed,
},
}
222 changes: 222 additions & 0 deletions src/api/providers/fetchers/__tests__/modelCache.custom.spec.ts
@@ -0,0 +1,222 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"
import { getModels, flushModels } from "../modelCache"
import * as customModels from "../../../../services/custom-models"
import * as openrouter from "../openrouter"

// Mock file data storage
const mockReadFileData: Record<string, any> = {}

// Mock the custom models service
vi.mock("../../../../services/custom-models", () => ({
getCustomModelsForProvider: vi.fn(),
}))

// Mock the openrouter fetcher
vi.mock("../openrouter", () => ({
getOpenRouterModels: vi.fn(),
}))

// Mock other dependencies
vi.mock("../../../../utils/path", () => ({
getWorkspacePath: vi.fn(() => "/test/workspace"),
}))

vi.mock("../../../../core/config/ContextProxy", () => ({
ContextProxy: {
instance: {
globalStorageUri: {
fsPath: "/test/storage",
},
},
},
}))

vi.mock("../../../../utils/storage", () => ({
getCacheDirectoryPath: vi.fn(() => "/test/cache"),
}))

// Mock safeWriteJson to populate our mock file data
vi.mock("../../../../utils/safeWriteJson", () => ({
safeWriteJson: vi.fn((filePath: string, data: any) => {
mockReadFileData[filePath] = data
return Promise.resolve()
}),
}))

// Mock fs.readFile to return the models that were written
vi.mock("fs/promises", () => ({
default: {
readFile: vi.fn((filePath: string) => {
const data = mockReadFileData[filePath]
if (!data) throw new Error("File not found")
return Promise.resolve(JSON.stringify(data))
}),
},
readFile: vi.fn((filePath: string) => {
const data = mockReadFileData[filePath]
if (!data) throw new Error("File not found")
return Promise.resolve(JSON.stringify(data))
}),
}))

vi.mock("../../../../utils/fs", () => ({
fileExistsAtPath: vi.fn((filePath: string) => {
return Promise.resolve(filePath in mockReadFileData)
}),
}))

describe("Model Cache with Custom Models", () => {
beforeEach(async () => {
vi.clearAllMocks()
// Clear both memory cache and mock file cache before each test
await flushModels("openrouter")
// Clear the mock file cache
Object.keys(mockReadFileData).forEach((key) => delete mockReadFileData[key])
})

afterEach(() => {
vi.restoreAllMocks()
})

it("should merge custom models with provider-fetched models", async () => {
const providerModels = {
"openai/gpt-4": {
maxTokens: 8000,
contextWindow: 128000,
supportsImages: true,
supportsPromptCache: false,
},
}

const customModelDefs = {
"custom/my-model": {
maxTokens: 4096,
contextWindow: 32000,
supportsPromptCache: false,
description: "My custom model",
},
}

vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce(customModelDefs)

const result = await getModels({ provider: "openrouter" })

expect(result).toEqual({
...providerModels,
...customModelDefs,
})
expect(openrouter.getOpenRouterModels).toHaveBeenCalledTimes(1)
expect(customModels.getCustomModelsForProvider).toHaveBeenCalledWith("openrouter", "/test/workspace")
})

it("should allow custom models to override provider models", async () => {
const providerModels = {
"openai/gpt-4": {
maxTokens: 8000,
contextWindow: 128000,
supportsImages: true,
supportsPromptCache: false,
},
}

const customModelDefs = {
"openai/gpt-4": {
maxTokens: 16000, // Override max tokens
contextWindow: 128000,
supportsImages: true,
supportsPromptCache: false,
description: "Custom GPT-4 with higher token limit",
},
}

vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce(customModelDefs)

const result = await getModels({ provider: "openrouter" })

expect(result["openai/gpt-4"]).toEqual(customModelDefs["openai/gpt-4"])
expect(result["openai/gpt-4"].maxTokens).toBe(16000)
})

it("should handle empty custom models gracefully", async () => {
const providerModels = {
"openai/gpt-4": {
maxTokens: 8000,
contextWindow: 128000,
supportsPromptCache: false,
},
}

vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce({})

const result = await getModels({ provider: "openrouter" })

expect(result).toEqual(providerModels)
})

it("should work when provider returns no models", async () => {
const customModelDefs = {
"custom/model-1": {
maxTokens: 4096,
contextWindow: 32000,
supportsPromptCache: false,
},
}

vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce({})
vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce(customModelDefs)

const result = await getModels({ provider: "openrouter" })

expect(result).toEqual(customModelDefs)
})

it("should handle errors in custom models loading gracefully", async () => {
const providerModels = {
"openai/gpt-4": {
maxTokens: 8000,
contextWindow: 128000,
supportsPromptCache: false,
},
}

vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
vi.mocked(customModels.getCustomModelsForProvider).mockRejectedValueOnce(
new Error("Failed to load custom models"),
)

// The error in loading custom models should cause the overall fetch to fail
await expect(getModels({ provider: "openrouter" })).rejects.toThrow("Failed to load custom models")
})

it("should flush cache for specific provider", async () => {
const providerModels = {
"openai/gpt-4": {
maxTokens: 8000,
contextWindow: 128000,
supportsPromptCache: false,
},
}

// First call - should fetch
vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce({})
await getModels({ provider: "openrouter" })
expect(openrouter.getOpenRouterModels).toHaveBeenCalledTimes(1)

// Second call - should use cache (no new mocks needed)
await getModels({ provider: "openrouter" })
expect(openrouter.getOpenRouterModels).toHaveBeenCalledTimes(1)

// Flush cache
await flushModels("openrouter")

// Third call - should fetch again (set up mock again)
vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce({})
await getModels({ provider: "openrouter" })
expect(openrouter.getOpenRouterModels).toHaveBeenCalledTimes(2)
})
})
6 changes: 6 additions & 0 deletions src/api/providers/fetchers/modelCache.ts
@@ -11,6 +11,8 @@ import { ContextProxy } from "../../../core/config/ContextProxy"
import { getCacheDirectoryPath } from "../../../utils/storage"
import type { RouterName, ModelRecord } from "../../../shared/api"
import { fileExistsAtPath } from "../../../utils/fs"
import { getCustomModelsForProvider } from "../../../services/custom-models"
import { getWorkspacePath } from "../../../utils/path"

import { getOpenRouterModels } from "./openrouter"
import { getVercelAiGatewayModels } from "./vercel-ai-gateway"
@@ -118,6 +120,10 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
}
}

// Load and merge custom models
const customModels = await getCustomModelsForProvider(provider, getWorkspacePath())
models = { ...models, ...customModels }

// Cache the fetched models (even if empty, to signify a successful fetch with no models).
memoryCache.set(provider, models)

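Because getModels caches results in memory, a change to a custom models file is only picked up after the provider's cache entry is flushed (the new test above exercises exactly this). A brief usage sketch under that assumption:

import { flushModels, getModels } from "./modelCache"

// After the user edits e.g. openrouter.json, drop the cached entry so the next
// lookup re-fetches the provider models and re-merges the custom definitions.
await flushModels("openrouter")
const models = await getModels({ provider: "openrouter" })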