Commit 5de3431

Support custom model configuration files

1 parent f93aafe commit 5de3431

25 files changed: +1594 additions, −169 deletions

packages/types/src/custom-models.ts

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import { z } from "zod"
+
+/**
+ * Schema for custom model information
+ * Defines the properties that can be specified for custom models
+ */
+export const customModelInfoSchema = z.object({
+	maxTokens: z.number().positive().optional(),
+	contextWindow: z.number().positive(),
+	supportsImages: z.boolean().optional(),
+	supportsPromptCache: z.boolean(), // Required in ModelInfo
+	supportsTemperature: z.boolean().optional(),
+	inputPrice: z.number().nonnegative().optional(),
+	outputPrice: z.number().nonnegative().optional(),
+	cacheWritesPrice: z.number().nonnegative().optional(),
+	cacheReadsPrice: z.number().nonnegative().optional(),
+	description: z.string().optional(),
+	supportsReasoningEffort: z.boolean().optional(),
+	supportsReasoningBudget: z.boolean().optional(),
+	requiredReasoningBudget: z.boolean().optional(),
+	reasoningEffort: z.string().optional(),
+})
+
+/**
+ * Schema for a custom models file
+ * The file is a simple record of model IDs to model information
+ * The provider is determined by the filename (e.g., openrouter.json)
+ */
+export const customModelsFileSchema = z.record(z.string(), customModelInfoSchema)
+
+/**
+ * Type for the content of a custom models file
+ */
+export type CustomModelsFile = z.infer<typeof customModelsFileSchema>
+
+/**
+ * Type for custom model information
+ */
+export type CustomModelInfo = z.infer<typeof customModelInfoSchema>
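
For illustration, a record like the following would satisfy these schemas. This is a sketch: the model ID and prices are invented, and the import path assumes a caller sitting next to this file; per the comment above, the provider is chosen by the file's name (e.g., openrouter.json).

import { customModelsFileSchema } from "./custom-models" // path assumed

// Hypothetical contents of a custom models file such as openrouter.json;
// keys are model IDs, values must satisfy customModelInfoSchema.
const fileContents = {
	"my-org/my-model": {
		contextWindow: 200000, // required
		supportsPromptCache: false, // required
		maxTokens: 8192,
		inputPrice: 1.5,
		outputPrice: 6.0,
		description: "Example custom model entry",
	},
}

// safeParse validates without throwing; failures carry structured issues.
const parsed = customModelsFileSchema.safeParse(fileContents)
if (!parsed.success) {
	console.error(parsed.error.format())
}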

packages/types/src/index.ts

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ export * from "./api.js"
 export * from "./cloud.js"
 export * from "./codebase-index.js"
 export * from "./cookie-consent.js"
+export * from "./custom-models.js"
 export * from "./events.js"
 export * from "./experiment.js"
 export * from "./followup.js"

src/api/providers/anthropic-vertex.ts

Lines changed: 4 additions & 2 deletions
@@ -16,6 +16,7 @@ import { safeJsonParse } from "../../shared/safeJsonParse"
 import { ApiStream } from "../transform/stream"
 import { addCacheBreakpoints } from "../transform/caching/vertex"
 import { getModelParams } from "../transform/model-params"
+import { getProviderModelsSync } from "./model-lookup"

 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -164,8 +165,9 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple

 	getModel() {
 		const modelId = this.options.apiModelId
-		let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId
-		const info: ModelInfo = vertexModels[id]
+		const models = getProviderModelsSync("vertex", vertexModels as Record<string, ModelInfo>)
+		let id = modelId && modelId in models ? (modelId as VertexModelId) : vertexDefaultModelId
+		const info: ModelInfo = models[id]
 		const params = getModelParams({ format: "anthropic", modelId: id, model: info, settings: this.options })

 		// The `:thinking` suffix indicates that the model is a "Hybrid"

src/api/providers/anthropic.ts

Lines changed: 4 additions & 2 deletions
@@ -14,6 +14,7 @@ import type { ApiHandlerOptions } from "../../shared/api"

 import { ApiStream } from "../transform/stream"
 import { getModelParams } from "../transform/model-params"
+import { getProviderModelsSync } from "./model-lookup"

 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -249,8 +250,9 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa

 	getModel() {
 		const modelId = this.options.apiModelId
-		let id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId
-		let info: ModelInfo = anthropicModels[id]
+		const models = getProviderModelsSync("anthropic", anthropicModels as Record<string, ModelInfo>)
+		let id = modelId && modelId in models ? (modelId as AnthropicModelId) : anthropicDefaultModelId
+		let info: ModelInfo = models[id]

 		// If 1M context beta is enabled for Claude Sonnet 4 or 4.5, update the model info
 		if ((id === "claude-sonnet-4-20250514" || id === "claude-sonnet-4-5") && this.options.anthropicBeta1MContext) {
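
Both Anthropic handlers now resolve their model table through getProviderModelsSync from ./model-lookup, which is imported above but not shown in this excerpt. A minimal sketch of the contract these call sites assume — a synchronous lookup that layers already-loaded custom definitions over the static table, with custom entries winning — could look like this; the cache variable and the ModelInfo import path are assumptions, not the commit's actual implementation:

import type { ModelInfo } from "@roo-code/types" // import path assumed

// Assumed module-level cache, populated elsewhere when custom model files are read.
const customModelCache: Record<string, Record<string, ModelInfo>> = {}

export function getProviderModelsSync(
	provider: string,
	staticModels: Record<string, ModelInfo>,
): Record<string, ModelInfo> {
	// Spread order gives custom definitions precedence, mirroring the merge in modelCache.ts.
	return { ...staticModels, ...(customModelCache[provider] ?? {}) }
}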

src/api/providers/bedrock.ts

Lines changed: 10 additions & 6 deletions
@@ -27,6 +27,7 @@ import {

 import { ApiStream } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
+import { getProviderModelsSync } from "./model-lookup"
 import { logger } from "../../utils/logging"
 import { Package } from "../../shared/package"
 import { MultiPointStrategy } from "../transform/cache-strategy/multi-point-strategy"
@@ -899,27 +900,30 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH

 	//Prompt Router responses come back in a different sequence and the model used is in the response and must be fetched by name
 	getModelById(modelId: string, modelType?: string): { id: BedrockModelId | string; info: ModelInfo } {
-		// Try to find the model in bedrockModels
+		// Get merged models (static + custom)
+		const models = getProviderModelsSync("bedrock", bedrockModels as Record<string, ModelInfo>)
+
+		// Try to find the model in merged models
 		const baseModelId = this.parseBaseModelId(modelId) as BedrockModelId

 		let model
-		if (baseModelId in bedrockModels) {
+		if (baseModelId in models) {
 			//Do a deep copy of the model info so that later in the code the model id and maxTokens can be set.
-			// The bedrockModels array is a constant and updating the model ID from the returned invokedModelID value
+			// The models array is a constant and updating the model ID from the returned invokedModelID value
 			// in a prompt router response isn't possible on the constant.
-			model = { id: baseModelId, info: JSON.parse(JSON.stringify(bedrockModels[baseModelId])) }
+			model = { id: baseModelId, info: JSON.parse(JSON.stringify(models[baseModelId])) }
 		} else if (modelType && modelType.includes("router")) {
 			model = {
 				id: bedrockDefaultPromptRouterModelId,
-				info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultPromptRouterModelId])),
+				info: JSON.parse(JSON.stringify(models[bedrockDefaultPromptRouterModelId])),
 			}
 		} else {
 			// Use heuristics for model info, then allow overrides from ProviderSettings
 			const guessed = this.guessModelInfoFromId(modelId)
 			model = {
 				id: bedrockDefaultModelId,
 				info: {
-					...JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
+					...JSON.parse(JSON.stringify(models[bedrockDefaultModelId])),
 					...guessed,
 				},
 			}
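
The branch order in getModelById is: exact base-model match against the merged table first, then the prompt-router fallback, then heuristics over the default model. Hypothetical calls for each branch — the model IDs and the router modelType string are invented, and handler is assumed to be an AwsBedrockHandler instance:

// Known base model ID: returns a mutable deep copy of the merged table entry.
handler.getModelById("anthropic.claude-3-sonnet-20240229-v1:0")

// modelType containing "router": falls back to bedrockDefaultPromptRouterModelId.
handler.getModelById("example-router-model", "default-prompt-router")

// Anything else: guessModelInfoFromId heuristics layered over bedrockDefaultModelId info.
handler.getModelById("some-unrecognized-model")
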
Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
+import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"
+import { getModels, flushModels } from "../modelCache"
+import * as customModels from "../../../../services/custom-models"
+import * as openrouter from "../openrouter"
+
+// Mock file data storage
+const mockReadFileData: Record<string, any> = {}
+
+// Mock the custom models service
+vi.mock("../../../../services/custom-models", () => ({
+	getCustomModelsForProvider: vi.fn(),
+}))
+
+// Mock the openrouter fetcher
+vi.mock("../openrouter", () => ({
+	getOpenRouterModels: vi.fn(),
+}))
+
+// Mock other dependencies
+vi.mock("../../../../utils/path", () => ({
+	getWorkspacePath: vi.fn(() => "/test/workspace"),
+}))
+
+vi.mock("../../../../core/config/ContextProxy", () => ({
+	ContextProxy: {
+		instance: {
+			globalStorageUri: {
+				fsPath: "/test/storage",
+			},
+		},
+	},
+}))
+
+vi.mock("../../../../utils/storage", () => ({
+	getCacheDirectoryPath: vi.fn(() => "/test/cache"),
+}))
+
+// Mock safeWriteJson to populate our mock file data
+vi.mock("../../../../utils/safeWriteJson", () => ({
+	safeWriteJson: vi.fn((filePath: string, data: any) => {
+		mockReadFileData[filePath] = data
+		return Promise.resolve()
+	}),
+}))
+
+// Mock fs.readFile to return the models that were written
+vi.mock("fs/promises", () => ({
+	default: {
+		readFile: vi.fn((filePath: string) => {
+			const data = mockReadFileData[filePath]
+			if (!data) throw new Error("File not found")
+			return Promise.resolve(JSON.stringify(data))
+		}),
+	},
+	readFile: vi.fn((filePath: string) => {
+		const data = mockReadFileData[filePath]
+		if (!data) throw new Error("File not found")
+		return Promise.resolve(JSON.stringify(data))
+	}),
+}))
+
+vi.mock("../../../../utils/fs", () => ({
+	fileExistsAtPath: vi.fn((filePath: string) => {
+		return Promise.resolve(filePath in mockReadFileData)
+	}),
+}))
+
+describe("Model Cache with Custom Models", () => {
+	beforeEach(async () => {
+		vi.clearAllMocks()
+		// Clear both memory cache and mock file cache before each test
+		await flushModels("openrouter")
+		// Clear the mock file cache
+		Object.keys(mockReadFileData).forEach((key) => delete mockReadFileData[key])
+	})
+
+	afterEach(() => {
+		vi.restoreAllMocks()
+	})
+
+	it("should merge custom models with provider-fetched models", async () => {
+		const providerModels = {
+			"openai/gpt-4": {
+				maxTokens: 8000,
+				contextWindow: 128000,
+				supportsImages: true,
+				supportsPromptCache: false,
+			},
+		}
+
+		const customModelDefs = {
+			"custom/my-model": {
+				maxTokens: 4096,
+				contextWindow: 32000,
+				supportsPromptCache: false,
+				description: "My custom model",
+			},
+		}
+
+		vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
+		vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce(customModelDefs)
+
+		const result = await getModels({ provider: "openrouter" })
+
+		expect(result).toEqual({
+			...providerModels,
+			...customModelDefs,
+		})
+		expect(openrouter.getOpenRouterModels).toHaveBeenCalledTimes(1)
+		expect(customModels.getCustomModelsForProvider).toHaveBeenCalledWith("openrouter", "/test/workspace")
+	})
+
+	it("should allow custom models to override provider models", async () => {
+		const providerModels = {
+			"openai/gpt-4": {
+				maxTokens: 8000,
+				contextWindow: 128000,
+				supportsImages: true,
+				supportsPromptCache: false,
+			},
+		}
+
+		const customModelDefs = {
+			"openai/gpt-4": {
+				maxTokens: 16000, // Override max tokens
+				contextWindow: 128000,
+				supportsImages: true,
+				supportsPromptCache: false,
+				description: "Custom GPT-4 with higher token limit",
+			},
+		}
+
+		vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
+		vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce(customModelDefs)
+
+		const result = await getModels({ provider: "openrouter" })
+
+		expect(result["openai/gpt-4"]).toEqual(customModelDefs["openai/gpt-4"])
+		expect(result["openai/gpt-4"].maxTokens).toBe(16000)
+	})
+
+	it("should handle empty custom models gracefully", async () => {
+		const providerModels = {
+			"openai/gpt-4": {
+				maxTokens: 8000,
+				contextWindow: 128000,
+				supportsPromptCache: false,
+			},
+		}
+
+		vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
+		vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce({})
+
+		const result = await getModels({ provider: "openrouter" })
+
+		expect(result).toEqual(providerModels)
+	})
+
+	it("should work when provider returns no models", async () => {
+		const customModelDefs = {
+			"custom/model-1": {
+				maxTokens: 4096,
+				contextWindow: 32000,
+				supportsPromptCache: false,
+			},
+		}
+
+		vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce({})
+		vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce(customModelDefs)
+
+		const result = await getModels({ provider: "openrouter" })
+
+		expect(result).toEqual(customModelDefs)
+	})
+
+	it("should handle errors in custom models loading gracefully", async () => {
+		const providerModels = {
+			"openai/gpt-4": {
+				maxTokens: 8000,
+				contextWindow: 128000,
+				supportsPromptCache: false,
+			},
+		}
+
+		vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
+		vi.mocked(customModels.getCustomModelsForProvider).mockRejectedValueOnce(
+			new Error("Failed to load custom models"),
+		)
+
+		// The error in loading custom models should cause the overall fetch to fail
+		await expect(getModels({ provider: "openrouter" })).rejects.toThrow("Failed to load custom models")
+	})
+
+	it("should flush cache for specific provider", async () => {
+		const providerModels = {
+			"openai/gpt-4": {
+				maxTokens: 8000,
+				contextWindow: 128000,
+				supportsPromptCache: false,
+			},
+		}
+
+		// First call - should fetch
+		vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
+		vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce({})
+		await getModels({ provider: "openrouter" })
+		expect(openrouter.getOpenRouterModels).toHaveBeenCalledTimes(1)
+
+		// Second call - should use cache (no new mocks needed)
+		await getModels({ provider: "openrouter" })
+		expect(openrouter.getOpenRouterModels).toHaveBeenCalledTimes(1)
+
+		// Flush cache
+		await flushModels("openrouter")
+
+		// Third call - should fetch again (set up mock again)
+		vi.mocked(openrouter.getOpenRouterModels).mockResolvedValueOnce(providerModels)
+		vi.mocked(customModels.getCustomModelsForProvider).mockResolvedValueOnce({})
+		await getModels({ provider: "openrouter" })
+		expect(openrouter.getOpenRouterModels).toHaveBeenCalledTimes(2)
+	})
+})

src/api/providers/fetchers/modelCache.ts

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,8 @@ import { ContextProxy } from "../../../core/config/ContextProxy"
 import { getCacheDirectoryPath } from "../../../utils/storage"
 import type { RouterName, ModelRecord } from "../../../shared/api"
 import { fileExistsAtPath } from "../../../utils/fs"
+import { getCustomModelsForProvider } from "../../../services/custom-models"
+import { getWorkspacePath } from "../../../utils/path"

 import { getOpenRouterModels } from "./openrouter"
 import { getVercelAiGatewayModels } from "./vercel-ai-gateway"
@@ -118,6 +120,10 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
 		}
 	}

+	// Load and merge custom models
+	const customModels = await getCustomModelsForProvider(provider, getWorkspacePath())
+	models = { ...models, ...customModels }
+
 	// Cache the fetched models (even if empty, to signify a successful fetch with no models).
 	memoryCache.set(provider, models)
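
Because the custom record is spread last, a custom entry with the same model ID overrides the fetched one, and it is the merged result that lands in the cache; note also that a rejection from getCustomModelsForProvider fails the whole fetch, as the tests above assert. A hypothetical consumer that forces a refetch when a custom models file changes — the watcher wiring is invented, but getModels and flushModels are the real modelCache exports:

import { flushModels, getModels } from "./modelCache" // path assumed relative to the caller

async function onCustomModelsFileChanged() {
	await flushModels("openrouter") // drop the stale merged record for this provider
	return getModels({ provider: "openrouter" }) // refetch, then re-merge custom models
}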