Skip to content

Commit 0d91250

Browse files
authored
refactor(js/plugins/vertexai): Dynamic model support for Gemini models (#2824)
refactor(js/plugins/googleai): Dynamic model support for Gemini models
1 parent f923a0b commit 0d91250

File tree

10 files changed

+251
-139
lines changed

10 files changed

+251
-139
lines changed

js/core/src/registry.ts

+7-6
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,9 @@ export class Registry {
241241
}
242242
return cached;
243243
},
244-
resolver: async (actionType: ActionType, target: string) => {
244+
resolver: async (actionType: ActionType, actionName: string) => {
245245
if (provider.resolver) {
246-
await provider.resolver(actionType, target);
246+
await provider.resolver(actionType, actionName);
247247
}
248248
},
249249
};
@@ -261,19 +261,20 @@ export class Registry {
261261
/**
262262
* Resolves a new Action dynamically by registering it.
263263
* @param pluginName The name of the plugin
264-
* @param action
264+
* @param actionType The type of the action
265+
* @param actionName The name of the action
265266
* @returns
266267
*/
267268
async resolvePluginAction(
268269
pluginName: string,
269-
action: ActionType,
270-
target: string
270+
actionType: ActionType,
271+
actionName: string
271272
) {
272273
const plugin = this.pluginsByName[pluginName];
273274
if (plugin) {
274275
return await runOutsideActionRuntimeContext(this, async () => {
275276
if (plugin.resolver) {
276-
await plugin.resolver(action, target);
277+
await plugin.resolver(actionType, actionName);
277278
}
278279
});
279280
}

js/plugins/googleai/src/gemini.ts

+1
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ function nearestGeminiModelRef(
430430
version,
431431
});
432432
}
433+
433434
return GENERIC_GEMINI_MODEL.withConfig({ ...options, version });
434435
}
435436

js/plugins/googleai/src/index.ts

+117-71
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import { Genkit, ModelReference } from 'genkit';
1818
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
19+
import { ActionType } from 'genkit/registry';
1920
import {
2021
SUPPORTED_MODELS as EMBEDDER_MODELS,
2122
defineGoogleAIEmbedder,
@@ -81,82 +82,127 @@ export interface PluginOptions {
8182
experimental_debugTraces?: boolean;
8283
}
8384

84-
/**
85-
* Google Gemini Developer API plugin.
86-
*/
87-
export function googleAI(options?: PluginOptions): GenkitPlugin {
88-
return genkitPlugin('googleai', async (ai: Genkit) => {
89-
let apiVersions = ['v1'];
85+
async function initializer(ai: Genkit, options?: PluginOptions) {
86+
let apiVersions = ['v1'];
9087

91-
if (options?.apiVersion) {
92-
if (Array.isArray(options?.apiVersion)) {
93-
apiVersions = options?.apiVersion;
94-
} else {
95-
apiVersions = [options?.apiVersion];
96-
}
88+
if (options?.apiVersion) {
89+
if (Array.isArray(options?.apiVersion)) {
90+
apiVersions = options?.apiVersion;
91+
} else {
92+
apiVersions = [options?.apiVersion];
9793
}
94+
}
9895

99-
if (apiVersions.includes('v1beta')) {
100-
Object.keys(SUPPORTED_V15_MODELS).forEach((name) =>
101-
defineGoogleAIModel({
102-
ai,
103-
name,
104-
apiKey: options?.apiKey,
105-
apiVersion: 'v1beta',
106-
baseUrl: options?.baseUrl,
107-
debugTraces: options?.experimental_debugTraces,
108-
})
109-
);
110-
}
111-
if (apiVersions.includes('v1')) {
112-
Object.keys(SUPPORTED_V1_MODELS).forEach((name) =>
113-
defineGoogleAIModel({
114-
ai,
115-
name,
116-
apiKey: options?.apiKey,
117-
apiVersion: undefined,
118-
baseUrl: options?.baseUrl,
119-
debugTraces: options?.experimental_debugTraces,
120-
})
121-
);
122-
Object.keys(SUPPORTED_V15_MODELS).forEach((name) =>
123-
defineGoogleAIModel({
124-
ai,
125-
name,
126-
apiKey: options?.apiKey,
127-
apiVersion: undefined,
128-
baseUrl: options?.baseUrl,
129-
debugTraces: options?.experimental_debugTraces,
130-
})
131-
);
132-
Object.keys(EMBEDDER_MODELS).forEach((name) =>
133-
defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey })
134-
);
135-
}
96+
if (apiVersions.includes('v1beta')) {
97+
Object.keys(SUPPORTED_V15_MODELS).forEach((name) =>
98+
defineGoogleAIModel({
99+
ai,
100+
name,
101+
apiKey: options?.apiKey,
102+
apiVersion: 'v1beta',
103+
baseUrl: options?.baseUrl,
104+
debugTraces: options?.experimental_debugTraces,
105+
})
106+
);
107+
}
108+
if (apiVersions.includes('v1')) {
109+
Object.keys(SUPPORTED_V1_MODELS).forEach((name) =>
110+
defineGoogleAIModel({
111+
ai,
112+
name,
113+
apiKey: options?.apiKey,
114+
apiVersion: undefined,
115+
baseUrl: options?.baseUrl,
116+
debugTraces: options?.experimental_debugTraces,
117+
})
118+
);
119+
Object.keys(SUPPORTED_V15_MODELS).forEach((name) =>
120+
defineGoogleAIModel({
121+
ai,
122+
name,
123+
apiKey: options?.apiKey,
124+
apiVersion: undefined,
125+
baseUrl: options?.baseUrl,
126+
debugTraces: options?.experimental_debugTraces,
127+
})
128+
);
129+
Object.keys(EMBEDDER_MODELS).forEach((name) =>
130+
defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey })
131+
);
132+
}
136133

137-
if (options?.models) {
138-
for (const modelOrRef of options?.models) {
139-
const modelName =
140-
typeof modelOrRef === 'string'
141-
? modelOrRef
142-
: // strip out the `googleai/` prefix
143-
modelOrRef.name.split('/')[1];
144-
const modelRef =
145-
typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef;
146-
defineGoogleAIModel({
147-
ai,
148-
name: modelName,
149-
apiKey: options?.apiKey,
150-
baseUrl: options?.baseUrl,
151-
info: {
152-
...modelRef.info,
153-
label: `Google AI - ${modelName}`,
154-
},
155-
debugTraces: options?.experimental_debugTraces,
156-
});
157-
}
134+
if (options?.models) {
135+
for (const modelOrRef of options?.models) {
136+
const modelName =
137+
typeof modelOrRef === 'string'
138+
? modelOrRef
139+
: // strip out the `googleai/` prefix
140+
modelOrRef.name.split('/')[1];
141+
const modelRef =
142+
typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef;
143+
defineGoogleAIModel({
144+
ai,
145+
name: modelName,
146+
apiKey: options?.apiKey,
147+
baseUrl: options?.baseUrl,
148+
info: {
149+
...modelRef.info,
150+
label: `Google AI - ${modelName}`,
151+
},
152+
debugTraces: options?.experimental_debugTraces,
153+
});
158154
}
159-
});
155+
}
156+
}
157+
158+
async function resolver(
159+
ai: Genkit,
160+
actionType: ActionType,
161+
actionName: string,
162+
options?: PluginOptions
163+
) {
164+
// TODO: also support other actions like 'embedder'
165+
switch (actionType) {
166+
case 'model':
167+
await resolveModel(ai, actionName, options);
168+
break;
169+
default:
170+
// no-op
171+
}
172+
}
173+
174+
async function resolveModel(
175+
ai: Genkit,
176+
actionName: string,
177+
options?: PluginOptions
178+
) {
179+
if (actionName.includes('gemini')) {
180+
const modelRef = gemini(actionName);
181+
defineGoogleAIModel({
182+
ai,
183+
name: modelRef.name,
184+
apiKey: options?.apiKey,
185+
baseUrl: options?.baseUrl,
186+
info: {
187+
...modelRef.info,
188+
label: `Google AI - ${actionName}`,
189+
},
190+
debugTraces: options?.experimental_debugTraces,
191+
});
192+
}
193+
// TODO: Support other models
194+
}
195+
196+
/**
197+
* Google Gemini Developer API plugin.
198+
*/
199+
export function googleAI(options?: PluginOptions): GenkitPlugin {
200+
return genkitPlugin(
201+
'googleai',
202+
async (ai: Genkit) => await initializer(ai, options),
203+
async (ai: Genkit, actionType: ActionType, actionName: string) =>
204+
await resolver(ai, actionType, actionName, options)
205+
);
160206
}
161207

162208
export default googleAI;

js/plugins/googleai/tests/gemini_test.ts

+13-7
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ describe('plugin', () => {
463463
assert.strictEqual(flash.__action.name, 'googleai/gemini-1.5-flash');
464464
});
465465

466-
it('references explicitly registered models', async () => {
466+
it('references both pre-registered or dynamic models', async () => {
467467
const flash002Ref = gemini('gemini-1.5-flash-002');
468468
const ai = genkit({
469469
plugins: [
@@ -525,13 +525,19 @@ describe('plugin', () => {
525525
GENERIC_GEMINI_MODEL.info! // <---- generic model fallback
526526
);
527527

528-
// this one is not registered
529-
const flash003Ref = gemini('gemini-1.5-flash-003');
530-
assert.strictEqual(flash003Ref.name, 'googleai/gemini-1.5-flash-003');
531-
const flash003 = await ai.registry.lookupAction(
532-
`/model/${flash003Ref.name}`
528+
// this one is dynamically resolved (not pre-registered)
529+
const giraffeRef = gemini('gemini-4.5-giraffe');
530+
assert.strictEqual(giraffeRef.name, 'googleai/gemini-4.5-giraffe');
531+
const giraffe = await ai.registry.lookupAction(
532+
`/model/${giraffeRef.name}`
533+
);
534+
assert.ok(giraffe);
535+
assert.strictEqual(giraffe.__action.name, 'googleai/gemini-4.5-giraffe');
536+
assertEqualModelInfo(
537+
giraffe.__action.metadata?.model,
538+
'Google AI - gemini-4.5-giraffe',
539+
GENERIC_GEMINI_MODEL.info! // <---- generic model fallback
533540
);
534-
assert.ok(flash003 === undefined);
535541
});
536542
});
537543
});

0 commit comments

Comments
 (0)