From d925df6828b722cd9f7687ca73158d9a31a356eb Mon Sep 17 00:00:00 2001 From: Ingrid Fielker Date: Wed, 23 Apr 2025 18:46:23 -0400 Subject: [PATCH 1/3] feat(js/plugins/vertexai): Dynamic model support for Gemini models --- js/core/src/registry.ts | 13 +-- js/plugins/vertexai/src/index.ts | 135 ++++++++++++++++-------- js/testapps/flow-simple-ai/package.json | 3 +- js/testapps/flow-simple-ai/src/index.ts | 8 +- package.json | 1 + pnpm-lock.yaml | 3 + 6 files changed, 109 insertions(+), 54 deletions(-) diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index 14ae232f6c..4993b39383 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -241,9 +241,9 @@ export class Registry { } return cached; }, - resolver: async (actionType: ActionType, target: string) => { + resolver: async (actionType: ActionType, actionName: string) => { if (provider.resolver) { - await provider.resolver(actionType, target); + await provider.resolver(actionType, actionName); } }, }; @@ -261,19 +261,20 @@ export class Registry { /** * Resolves a new Action dynamically by registering it. * @param pluginName The name of the plugin - * @param action + * @param actionType The type of the action + * @param actionName The name of the action * @returns */ async resolvePluginAction( pluginName: string, - action: ActionType, - target: string + actionType: ActionType, + actionName: string ) { const plugin = this.pluginsByName[pluginName]; if (plugin) { return await runOutsideActionRuntimeContext(this, async () => { if (plugin.resolver) { - await plugin.resolver(action, target); + await plugin.resolver(actionType, actionName); } }); } diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts index 830031a021..b994fa0147 100644 --- a/js/plugins/vertexai/src/index.ts +++ b/js/plugins/vertexai/src/index.ts @@ -22,6 +22,7 @@ import { Genkit } from 'genkit'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; +import { ActionType } from 'genkit/registry'; import { getDerivedParams } from './common/index.js'; import { PluginOptions } from './common/types.js'; import { @@ -85,57 +86,105 @@ export { type GeminiConfig, }; -/** - * Add Google Cloud Vertex AI to Genkit. Includes Gemini and Imagen models and text embedder. - */ -export function vertexAI(options?: PluginOptions): GenkitPlugin { - return genkitPlugin('vertexai', async (ai: Genkit) => { - const { projectId, location, vertexClientFactory, authClient } = - await getDerivedParams(options); +async function initializer(ai: Genkit, options?: PluginOptions) { + const { projectId, location, vertexClientFactory, authClient } = + await getDerivedParams(options); - Object.keys(SUPPORTED_IMAGEN_MODELS).map((name) => - imagenModel(ai, name, authClient, { projectId, location }) - ); - Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => - defineGeminiKnownModel( + Object.keys(SUPPORTED_IMAGEN_MODELS).map((name) => + imagenModel(ai, name, authClient, { projectId, location }) + ); + Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => + defineGeminiKnownModel( + ai, + name, + vertexClientFactory, + { + projectId, + location, + }, + options?.experimental_debugTraces + ) + ); + if (options?.models) { + for (const modelOrRef of options?.models) { + const modelName = + typeof modelOrRef === 'string' + ? modelOrRef + : // strip out the `vertexai/` prefix + modelOrRef.name.split('/')[1]; + const modelRef = + typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef; + defineGeminiModel({ ai, - name, + modelName: modelRef.name, + version: modelName, + modelInfo: modelRef.info, vertexClientFactory, - { + options: { projectId, location, }, - options?.experimental_debugTraces - ) - ); - if (options?.models) { - for (const modelOrRef of options?.models) { - const modelName = - typeof modelOrRef === 'string' - ? modelOrRef - : // strip out the `vertexai/` prefix - modelOrRef.name.split('/')[1]; - const modelRef = - typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef; - defineGeminiModel({ - ai, - modelName: modelRef.name, - version: modelName, - modelInfo: modelRef.info, - vertexClientFactory, - options: { - projectId, - location, - }, - debugTraces: options.experimental_debugTraces, - }); - } + debugTraces: options.experimental_debugTraces, + }); } + } + + Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) => + defineVertexAIEmbedder(ai, name, authClient, { projectId, location }) + ); +} + +async function resolver( + ai: Genkit, + action: ActionType, + target: string, + options?: PluginOptions +) { + // TODO: also support other actions like 'embedder' + switch (action) { + case 'model': + await resolveModel(ai, target, options); + break; + default: + // no-op + } +} + +async function resolveModel( + ai: Genkit, + target: string, + options?: PluginOptions +) { + const { projectId, location, vertexClientFactory } = + await getDerivedParams(options); + if (target.includes('gemini')) { + const modelRef = gemini(target); + defineGeminiModel({ + ai, + modelName: modelRef.name, + version: target, + modelInfo: modelRef.info, + vertexClientFactory, + options: { + projectId, + location, + }, + debugTraces: options?.experimental_debugTraces, + }); + } + // TODO: Support other models +} - Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) => - defineVertexAIEmbedder(ai, name, authClient, { projectId, location }) - ); - }); +/** + * Add Google Cloud Vertex AI to Genkit. Includes Gemini and Imagen models and text embedder. + */ +export function vertexAI(options?: PluginOptions): GenkitPlugin { + return genkitPlugin( + 'vertexai', + async (ai: Genkit) => await initializer(ai, options), + async (ai: Genkit, action: ActionType, target: string) => + await resolver(ai, action, target, options) + ); } export default vertexAI; diff --git a/js/testapps/flow-simple-ai/package.json b/js/testapps/flow-simple-ai/package.json index dc91b820a3..d209ba24af 100644 --- a/js/testapps/flow-simple-ai/package.json +++ b/js/testapps/flow-simple-ai/package.json @@ -9,7 +9,8 @@ "build": "pnpm build:clean && pnpm compile", "build:clean": "rimraf ./lib", "build:watch": "tsc --watch", - "build-and-run": "pnpm build && node lib/index.js" + "build-and-run": "pnpm build && node lib/index.js", + "genkit:dev": "genkit start -- tsx --watch src/index.ts" }, "keywords": [], "author": "", diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 2d40070ab6..0adfe575e1 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -104,9 +104,9 @@ export const jokeFlow = ai.defineFlow( { name: 'jokeFlow', inputSchema: z.object({ - modelName: z.string(), - modelVersion: z.string().optional(), - subject: z.string(), + modelName: z.string().default('vertexai/gemini-2.5-pro-exp-03-25'), + modelVersion: z.string().optional().default('gemini-2.5-pro-exp-03-25'), + subject: z.string().default('bananas'), }), outputSchema: z.string(), }, @@ -740,7 +740,7 @@ ai.defineFlow('formatJsonManualSchema', async (input, { sendChunk }) => { const { output, text } = await ai.generate({ model: gemini15Flash, prompt: `generate one RPG game character of type ${input || 'archer'} and generated JSON must match this interface - + \`\`\`typescript interface Character { name: string; diff --git a/package.json b/package.json index 0f4061916a..3855e7945b 100644 --- a/package.json +++ b/package.json @@ -30,6 +30,7 @@ "format:check" ], "devDependencies": { + "@types/node": "22.10.5", "inquirer": "^8.0.0", "npm-run-all": "^4.1.5", "only-allow": "^1.2.1", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 2a3c90e10b..67c07f0abd 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -16,6 +16,9 @@ importers: specifier: link:js/genkit version: link:js/genkit devDependencies: + '@types/node': + specifier: 22.10.5 + version: 22.10.5 inquirer: specifier: ^8.0.0 version: 8.2.6 From d2286fd783d591cc73f465c0a5fbad2882ce856a Mon Sep 17 00:00:00 2001 From: Ingrid Fielker Date: Fri, 25 Apr 2025 14:04:42 -0400 Subject: [PATCH 2/3] test(js/plugins/vertexai): modified the test --- js/plugins/vertexai/tests/plugin_test.ts | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/js/plugins/vertexai/tests/plugin_test.ts b/js/plugins/vertexai/tests/plugin_test.ts index 9fc731b707..4d4fcf76dd 100644 --- a/js/plugins/vertexai/tests/plugin_test.ts +++ b/js/plugins/vertexai/tests/plugin_test.ts @@ -62,7 +62,7 @@ describe('plugin', () => { assert.strictEqual(flash.__action.name, 'vertexai/gemini-1.5-flash'); }); - it('references explicitly registered models', async () => { + it('references both pre-registered or dynamic models', async () => { const flash002Ref = gemini('gemini-1.5-flash-002'); const ai = genkit({ plugins: [ @@ -122,13 +122,17 @@ describe('plugin', () => { GENERIC_GEMINI_MODEL.info! // <---- generic model fallback ); - // this one is not registered - const flash003Ref = gemini('gemini-1.5-flash-003'); - assert.strictEqual(flash003Ref.name, 'vertexai/gemini-1.5-flash-003'); - const flash003 = await ai.registry.lookupAction( - `/model/${flash003Ref.name}` + // this one is dynamically resolved (not pre-registered) + const giraffeRef = gemini('gemini-4.5-giraffe'); + assert.strictEqual(giraffeRef.name, 'vertexai/gemini-4.5-giraffe'); + const giraffe = await ai.registry.lookupAction(`/model/${giraffeRef.name}`); + assert.ok(giraffe); + assert.strictEqual(giraffe.__action.name, 'vertexai/gemini-4.5-giraffe'); + assertEqualModelInfo( + giraffe.__action.metadata?.model, + 'Vertex AI - gemini-4.5-giraffe', + GENERIC_GEMINI_MODEL.info! // <---- generic model fallback ); - assert.ok(flash003 === undefined); }); }); From 8861f983f60af880367a1a739c80d06f2c04dd51 Mon Sep 17 00:00:00 2001 From: Ingrid Fielker Date: Fri, 25 Apr 2025 16:01:50 -0400 Subject: [PATCH 3/3] feat(js/plugins/googleai) Dynamic model support for Gemini models --- js/plugins/googleai/src/gemini.ts | 1 + js/plugins/googleai/src/index.ts | 188 ++++++++++++++--------- js/plugins/googleai/tests/gemini_test.ts | 20 ++- js/plugins/vertexai/src/index.ts | 20 +-- 4 files changed, 141 insertions(+), 88 deletions(-) diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index 8299f68365..3b287f5dbc 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -430,6 +430,7 @@ function nearestGeminiModelRef( version, }); } + return GENERIC_GEMINI_MODEL.withConfig({ ...options, version }); } diff --git a/js/plugins/googleai/src/index.ts b/js/plugins/googleai/src/index.ts index 0c8cbe135d..09d89629a9 100644 --- a/js/plugins/googleai/src/index.ts +++ b/js/plugins/googleai/src/index.ts @@ -16,6 +16,7 @@ import { Genkit, ModelReference } from 'genkit'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; +import { ActionType } from 'genkit/registry'; import { SUPPORTED_MODELS as EMBEDDER_MODELS, defineGoogleAIEmbedder, @@ -81,82 +82,127 @@ export interface PluginOptions { experimental_debugTraces?: boolean; } -/** - * Google Gemini Developer API plugin. - */ -export function googleAI(options?: PluginOptions): GenkitPlugin { - return genkitPlugin('googleai', async (ai: Genkit) => { - let apiVersions = ['v1']; +async function initializer(ai: Genkit, options?: PluginOptions) { + let apiVersions = ['v1']; - if (options?.apiVersion) { - if (Array.isArray(options?.apiVersion)) { - apiVersions = options?.apiVersion; - } else { - apiVersions = [options?.apiVersion]; - } + if (options?.apiVersion) { + if (Array.isArray(options?.apiVersion)) { + apiVersions = options?.apiVersion; + } else { + apiVersions = [options?.apiVersion]; } + } - if (apiVersions.includes('v1beta')) { - Object.keys(SUPPORTED_V15_MODELS).forEach((name) => - defineGoogleAIModel({ - ai, - name, - apiKey: options?.apiKey, - apiVersion: 'v1beta', - baseUrl: options?.baseUrl, - debugTraces: options?.experimental_debugTraces, - }) - ); - } - if (apiVersions.includes('v1')) { - Object.keys(SUPPORTED_V1_MODELS).forEach((name) => - defineGoogleAIModel({ - ai, - name, - apiKey: options?.apiKey, - apiVersion: undefined, - baseUrl: options?.baseUrl, - debugTraces: options?.experimental_debugTraces, - }) - ); - Object.keys(SUPPORTED_V15_MODELS).forEach((name) => - defineGoogleAIModel({ - ai, - name, - apiKey: options?.apiKey, - apiVersion: undefined, - baseUrl: options?.baseUrl, - debugTraces: options?.experimental_debugTraces, - }) - ); - Object.keys(EMBEDDER_MODELS).forEach((name) => - defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey }) - ); - } + if (apiVersions.includes('v1beta')) { + Object.keys(SUPPORTED_V15_MODELS).forEach((name) => + defineGoogleAIModel({ + ai, + name, + apiKey: options?.apiKey, + apiVersion: 'v1beta', + baseUrl: options?.baseUrl, + debugTraces: options?.experimental_debugTraces, + }) + ); + } + if (apiVersions.includes('v1')) { + Object.keys(SUPPORTED_V1_MODELS).forEach((name) => + defineGoogleAIModel({ + ai, + name, + apiKey: options?.apiKey, + apiVersion: undefined, + baseUrl: options?.baseUrl, + debugTraces: options?.experimental_debugTraces, + }) + ); + Object.keys(SUPPORTED_V15_MODELS).forEach((name) => + defineGoogleAIModel({ + ai, + name, + apiKey: options?.apiKey, + apiVersion: undefined, + baseUrl: options?.baseUrl, + debugTraces: options?.experimental_debugTraces, + }) + ); + Object.keys(EMBEDDER_MODELS).forEach((name) => + defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey }) + ); + } - if (options?.models) { - for (const modelOrRef of options?.models) { - const modelName = - typeof modelOrRef === 'string' - ? modelOrRef - : // strip out the `googleai/` prefix - modelOrRef.name.split('/')[1]; - const modelRef = - typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef; - defineGoogleAIModel({ - ai, - name: modelName, - apiKey: options?.apiKey, - baseUrl: options?.baseUrl, - info: { - ...modelRef.info, - label: `Google AI - ${modelName}`, - }, - debugTraces: options?.experimental_debugTraces, - }); - } + if (options?.models) { + for (const modelOrRef of options?.models) { + const modelName = + typeof modelOrRef === 'string' + ? modelOrRef + : // strip out the `googleai/` prefix + modelOrRef.name.split('/')[1]; + const modelRef = + typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef; + defineGoogleAIModel({ + ai, + name: modelName, + apiKey: options?.apiKey, + baseUrl: options?.baseUrl, + info: { + ...modelRef.info, + label: `Google AI - ${modelName}`, + }, + debugTraces: options?.experimental_debugTraces, + }); } - }); + } +} + +async function resolver( + ai: Genkit, + actionType: ActionType, + actionName: string, + options?: PluginOptions +) { + // TODO: also support other actions like 'embedder' + switch (actionType) { + case 'model': + await resolveModel(ai, actionName, options); + break; + default: + // no-op + } +} + +async function resolveModel( + ai: Genkit, + actionName: string, + options?: PluginOptions +) { + if (actionName.includes('gemini')) { + const modelRef = gemini(actionName); + defineGoogleAIModel({ + ai, + name: modelRef.name, + apiKey: options?.apiKey, + baseUrl: options?.baseUrl, + info: { + ...modelRef.info, + label: `Google AI - ${actionName}`, + }, + debugTraces: options?.experimental_debugTraces, + }); + } + // TODO: Support other models +} + +/** + * Google Gemini Developer API plugin. + */ +export function googleAI(options?: PluginOptions): GenkitPlugin { + return genkitPlugin( + 'googleai', + async (ai: Genkit) => await initializer(ai, options), + async (ai: Genkit, actionType: ActionType, actionName: string) => + await resolver(ai, actionType, actionName, options) + ); } export default googleAI; diff --git a/js/plugins/googleai/tests/gemini_test.ts b/js/plugins/googleai/tests/gemini_test.ts index 85a7768e6a..45d2ad8732 100644 --- a/js/plugins/googleai/tests/gemini_test.ts +++ b/js/plugins/googleai/tests/gemini_test.ts @@ -463,7 +463,7 @@ describe('plugin', () => { assert.strictEqual(flash.__action.name, 'googleai/gemini-1.5-flash'); }); - it('references explicitly registered models', async () => { + it('references both pre-registered or dynamic models', async () => { const flash002Ref = gemini('gemini-1.5-flash-002'); const ai = genkit({ plugins: [ @@ -525,13 +525,19 @@ describe('plugin', () => { GENERIC_GEMINI_MODEL.info! // <---- generic model fallback ); - // this one is not registered - const flash003Ref = gemini('gemini-1.5-flash-003'); - assert.strictEqual(flash003Ref.name, 'googleai/gemini-1.5-flash-003'); - const flash003 = await ai.registry.lookupAction( - `/model/${flash003Ref.name}` + // this one is dynamically resolved (not pre-registered) + const giraffeRef = gemini('gemini-4.5-giraffe'); + assert.strictEqual(giraffeRef.name, 'googleai/gemini-4.5-giraffe'); + const giraffe = await ai.registry.lookupAction( + `/model/${giraffeRef.name}` + ); + assert.ok(giraffe); + assert.strictEqual(giraffe.__action.name, 'googleai/gemini-4.5-giraffe'); + assertEqualModelInfo( + giraffe.__action.metadata?.model, + 'Google AI - gemini-4.5-giraffe', + GENERIC_GEMINI_MODEL.info! // <---- generic model fallback ); - assert.ok(flash003 === undefined); }); }); }); diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts index b994fa0147..aba1ecb1cd 100644 --- a/js/plugins/vertexai/src/index.ts +++ b/js/plugins/vertexai/src/index.ts @@ -136,14 +136,14 @@ async function initializer(ai: Genkit, options?: PluginOptions) { async function resolver( ai: Genkit, - action: ActionType, - target: string, + actionType: ActionType, + actionName: string, options?: PluginOptions ) { // TODO: also support other actions like 'embedder' - switch (action) { + switch (actionType) { case 'model': - await resolveModel(ai, target, options); + await resolveModel(ai, actionName, options); break; default: // no-op @@ -152,17 +152,17 @@ async function resolver( async function resolveModel( ai: Genkit, - target: string, + actionName: string, options?: PluginOptions ) { const { projectId, location, vertexClientFactory } = await getDerivedParams(options); - if (target.includes('gemini')) { - const modelRef = gemini(target); + if (actionName.includes('gemini')) { + const modelRef = gemini(actionName); defineGeminiModel({ ai, modelName: modelRef.name, - version: target, + version: actionName, modelInfo: modelRef.info, vertexClientFactory, options: { @@ -182,8 +182,8 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin { return genkitPlugin( 'vertexai', async (ai: Genkit) => await initializer(ai, options), - async (ai: Genkit, action: ActionType, target: string) => - await resolver(ai, action, target, options) + async (ai: Genkit, actionType: ActionType, actionName: string) => + await resolver(ai, actionType, actionName, options) ); }