Skip to content

feat(js/plugins/ollama): implemented dynamic model resolver and list actions for ollama plugin #2831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 30, 2025
250 changes: 208 additions & 42 deletions js/plugins/ollama/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,37 @@
* limitations under the License.
*/

import { Genkit, ToolRequest, ToolRequestPart, ToolResponse, z } from 'genkit';
import {
ActionMetadata,
embedderRef,
Genkit,
modelActionMetadata,
ToolRequest,
ToolRequestPart,
ToolResponse,
z,
type EmbedderReference,
type ModelReference,
} from 'genkit';
import { logger } from 'genkit/logging';
import {
GenerateRequest,
GenerateResponseData,
GenerationCommonConfigDescriptions,
GenerationCommonConfigSchema,
getBasicUsageStats,
MessageData,
ModelInfo,
modelRef,
ToolDefinition,
getBasicUsageStats,
} from 'genkit/model';
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
import { ActionType } from 'genkit/registry';
import { defineOllamaEmbedder } from './embeddings.js';
import {
ApiType,
ListLocalModelsResponse,
LocalModel,
Message,
ModelDefinition,
OllamaTool,
Expand All @@ -39,25 +55,136 @@ import {

export { type OllamaPluginParams };

/**
 * Shape of the exported `ollama` symbol: a callable plugin factory plus
 * helpers for building model/embedder references by name.
 */
export type OllamaPlugin = {
  (params?: OllamaPluginParams): GenkitPlugin;

  /** Returns a reference to an Ollama model, e.g. `ollama.model('gemma')`. */
  model(
    name: string,
    config?: z.infer<typeof OllamaConfigSchema>
  ): ModelReference<typeof OllamaConfigSchema>;
  /** Returns a reference to an Ollama embedder by name. */
  embedder(name: string, config?: Record<string, any>): EmbedderReference;
};

// Draft-07 JSON schema with no constraints, i.e. it accepts any value.
// NOTE(review): its usage is not visible in this excerpt — presumably a
// catch-all for tool parameter/output schemas; confirm against the full file.
const ANY_JSON_SCHEMA: Record<string, any> = {
  $schema: 'http://json-schema.org/draft-07/schema#',
};

export function ollama(params: OllamaPluginParams): GenkitPlugin {
return genkitPlugin('ollama', async (ai: Genkit) => {
const serverAddress = params.serverAddress;
params.models?.map((model) =>
ollamaModel(ai, model, serverAddress, params.requestHeaders)
);
params.embedders?.map((model) =>
defineOllamaEmbedder(ai, {
name: model.name,
modelName: model.name,
dimensions: model.dimensions,
options: params,
})
/**
 * Capability metadata assumed for dynamically discovered models. The Ollama
 * tags API does not report per-model capabilities, so we optimistically
 * advertise support for everything.
 *
 * Declared with a type annotation (not an `as` assertion) so the compiler
 * actually validates the literal against `ModelInfo`.
 */
const GENERIC_MODEL_INFO: ModelInfo = {
  supports: {
    multiturn: true,
    media: true,
    tools: true,
    toolChoice: true,
    systemRole: true,
    constrained: 'all',
  },
};

// Address used when the plugin params omit `serverAddress`.
const DEFAULT_OLLAMA_SERVER_ADDRESS = 'http://localhost:11434';

/**
 * Eagerly registers the models and embedders declared in the plugin params.
 *
 * @param ai Genkit instance to register actions on.
 * @param serverAddress resolved Ollama server address.
 * @param params plugin configuration; a no-op when omitted.
 */
async function initializer(
  ai: Genkit,
  serverAddress: string,
  params?: OllamaPluginParams
) {
  if (!params) {
    return;
  }
  // Bind to a const so the non-null narrowing survives into the callbacks,
  // avoiding the `params!` non-null assertion.
  const config = params;
  // forEach (not map): only the registration side effect matters.
  config.models?.forEach((model) =>
    defineOllamaModel(ai, model, serverAddress, config.requestHeaders)
  );
  config.embedders?.forEach((model) =>
    defineOllamaEmbedder(ai, {
      name: model.name,
      modelName: model.name,
      dimensions: model.dimensions,
      options: config,
    })
  );
}

function resolveAction(
ai: Genkit,
actionType: ActionType,
actionName: string,
serverAddress: string,
requestHeaders?: RequestHeaders
) {
// We can only dynamically resolve models, for embedders user must provide dimensions.
if (actionType === 'model') {
defineOllamaModel(
ai,
{
name: actionName,
},
serverAddress,
requestHeaders
);
});
}
}

/**
 * Builds action metadata for every chat-capable model installed on the
 * Ollama server, for dynamic listing in the registry.
 */
async function listActions(
  serverAddress: string,
  requestHeaders?: RequestHeaders
): Promise<ActionMetadata[]> {
  const models = await listLocalModels(serverAddress, requestHeaders);
  if (!models) {
    return [];
  }
  // Naively filter out embedders; the tags API offers no better signal.
  const chatModels = models.filter((m) => m.model && !m.model.includes('embed'));
  return chatModels.map((m) =>
    modelActionMetadata({
      name: `ollama/${m.model}`,
      info: GENERIC_MODEL_INFO,
    })
  );
}

/**
 * Builds the Ollama Genkit plugin: registers configured models/embedders
 * eagerly, resolves unknown model actions on demand, and lists locally
 * installed models.
 *
 * @param params plugin options; `serverAddress` defaults to
 *   http://localhost:11434 when omitted.
 */
function ollamaPlugin(params?: OllamaPluginParams): GenkitPlugin {
  // Resolve the address locally instead of mutating the caller's params
  // object (the original reassigned `params.serverAddress` in place).
  const serverAddress = params?.serverAddress ?? DEFAULT_OLLAMA_SERVER_ADDRESS;
  return genkitPlugin(
    'ollama',
    async (ai: Genkit) => {
      await initializer(ai, serverAddress, params);
    },
    async (ai, actionType, actionName) => {
      resolveAction(
        ai,
        actionType,
        actionName,
        serverAddress,
        params?.requestHeaders
      );
    },
    async () => await listActions(serverAddress, params?.requestHeaders)
  );
}

/**
 * Fetches the models installed on the Ollama server via the list-local-models
 * API: https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 *
 * @throws when the server is unreachable (wrapped with a hint) or responds
 *   with a non-2xx status.
 */
async function listLocalModels(
  serverAddress: string,
  requestHeaders?: RequestHeaders
): Promise<LocalModel[]> {
  let res;
  try {
    res = await fetch(serverAddress + '/api/tags', {
      method: 'GET',
      headers: {
        'Content-Type': 'application/json',
        ...(await getHeaders(serverAddress, requestHeaders)),
      },
    });
  } catch (e) {
    throw new Error(`Make sure the Ollama server is running.`, {
      cause: e,
    });
  }
  // Surface HTTP-level failures explicitly instead of letting them show up
  // later as a cryptic JSON parse error.
  if (!res.ok) {
    throw new Error(
      `Failed to list Ollama models: ${res.status} ${res.statusText}`
    );
  }
  const modelResponse = (await res.json()) as ListLocalModelsResponse;
  return modelResponse.models;
}

/**
Expand Down Expand Up @@ -92,7 +219,7 @@ export const OllamaConfigSchema = GenerationCommonConfigSchema.extend({
.optional(),
});

function ollamaModel(
function defineOllamaModel(
ai: Genkit,
model: ModelDefinition,
serverAddress: string,
Expand All @@ -110,21 +237,20 @@ function ollamaModel(
},
},
async (input, streamingCallback) => {
const options: Record<string, any> = {};
if (input.config?.temperature !== undefined) {
options.temperature = input.config.temperature;
}
if (input.config?.topP !== undefined) {
options.top_p = input.config.topP;
const { topP, topK, stopSequences, maxOutputTokens, ...rest } =
input.config as any;
const options: Record<string, any> = { ...rest };
if (topP !== undefined) {
options.top_p = topP;
}
if (input.config?.topK !== undefined) {
options.top_k = input.config.topK;
if (topK !== undefined) {
options.top_k = topK;
}
if (input.config?.stopSequences !== undefined) {
options.stop = input.config.stopSequences.join('');
if (stopSequences !== undefined) {
options.stop = stopSequences.join('');
}
if (input.config?.maxOutputTokens !== undefined) {
options.num_predict = input.config.maxOutputTokens;
if (maxOutputTokens !== undefined) {
options.num_predict = maxOutputTokens;
}
const type = model.type ?? 'chat';
const request = toOllamaRequest(
Expand All @@ -136,18 +262,12 @@ function ollamaModel(
);
logger.debug(request, `ollama request (${type})`);

const extraHeaders = requestHeaders
? typeof requestHeaders === 'function'
? await requestHeaders(
{
serverAddress,
model,
},
input
)
: requestHeaders
: {};

const extraHeaders = await getHeaders(
serverAddress,
requestHeaders,
model,
input
);
let res;
try {
res = await fetch(
Expand Down Expand Up @@ -252,6 +372,25 @@ function parseMessage(response: any, type: ApiType): MessageData {
}
}

/**
 * Resolves the extra request headers configured on the plugin.
 *
 * Static header objects are returned as-is; header-producing functions are
 * invoked with the server address, model, and request. Returns `{}` when no
 * headers were configured.
 */
async function getHeaders(
  serverAddress: string,
  requestHeaders?: RequestHeaders,
  model?: ModelDefinition,
  input?: GenerateRequest
): Promise<Record<string, string> | void> {
  if (!requestHeaders) {
    return {};
  }
  if (typeof requestHeaders === 'function') {
    return await requestHeaders({ serverAddress, model }, input);
  }
  return requestHeaders;
}

function toOllamaRequest(
name: string,
input: GenerateRequest,
Expand All @@ -278,7 +417,13 @@ function toOllamaRequest(
messageText += c.text;
}
if (c.media) {
images.push(c.media.url);
let imageUri = c.media.url;
// ollama doesn't accept full data URIs, just the base64 encoded image,
// strip out data URI prefix (ex. `data:image/jpeg;base64,`)
if (imageUri.startsWith('data:')) {
imageUri = imageUri.substring(imageUri.indexOf(',') + 1);
}
images.push(imageUri);
}
if (c.toolRequest) {
toolRequests.push(c.toolRequest);
Expand Down Expand Up @@ -391,3 +536,24 @@ function isValidOllamaTool(tool: ToolDefinition): boolean {
}
return true;
}

/**
 * Public plugin entry point. The factory is augmented with `model` and
 * `embedder` helpers that build namespaced references.
 */
export const ollama = ollamaPlugin as OllamaPlugin;

ollama.model = (
  name: string,
  config?: z.infer<typeof OllamaConfigSchema>
): ModelReference<typeof OllamaConfigSchema> =>
  modelRef({
    name: `ollama/${name}`,
    configSchema: OllamaConfigSchema,
    config,
  });

ollama.embedder = (
  name: string,
  config?: Record<string, any>
): EmbedderReference => embedderRef({ name: `ollama/${name}`, config });
29 changes: 26 additions & 3 deletions js/plugins/ollama/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ export interface OllamaPluginParams {
embedders?: EmbeddingModelDefinition[];

/**
* The address of the Ollama server.
* The address of the Ollama server. Default: http://localhost:11434
*/
serverAddress: string;
serverAddress?: string;

/**
* Optional request headers, which can be either static or dynamically generated.
Expand Down Expand Up @@ -116,7 +116,7 @@ export interface RequestHeaderFunction {
(
params: {
serverAddress: string;
model: ModelDefinition | EmbeddingModelDefinition;
model?: ModelDefinition | EmbeddingModelDefinition;
modelRequest?: GenerateRequest;
embedRequest?: EmbedRequest;
},
Expand Down Expand Up @@ -157,3 +157,26 @@ export interface Message {
images?: string[];
tool_calls?: any[];
}

/**
 * A locally installed model as reported by the Ollama list-local-models
 * (`/api/tags`) endpoint.
 */
export interface LocalModel {
  name: string;
  model: string;
  // ISO 8601 format date
  modified_at: string;
  // Size on disk — bytes per the Ollama API docs; confirm if it matters.
  size: number;
  digest: string;
  // Optional metadata block; fields mirror the Ollama API response.
  details?: {
    parent_model?: string;
    format?: string;
    family?: string;
    families?: string[];
    parameter_size?: string;
    quantization_level?: string;
  };
}

/** Response body of the Ollama list-local-models (`/api/tags`) endpoint. */
export interface ListLocalModelsResponse {
  models: LocalModel[];
}
Loading