Skip to content

Commit 2023053

Browse files
authored
feat(js/plugins/ollama): implemented dynamic model resolver and list actions for ollama plugin (#2831)
1 parent ea6a84d commit 2023053

File tree

3 files changed

+257
-50
lines changed

3 files changed

+257
-50
lines changed

js/plugins/ollama/src/index.ts

+208-42
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,37 @@
1414
* limitations under the License.
1515
*/
1616

17-
import { Genkit, ToolRequest, ToolRequestPart, ToolResponse, z } from 'genkit';
17+
import {
18+
ActionMetadata,
19+
embedderRef,
20+
Genkit,
21+
modelActionMetadata,
22+
ToolRequest,
23+
ToolRequestPart,
24+
ToolResponse,
25+
z,
26+
type EmbedderReference,
27+
type ModelReference,
28+
} from 'genkit';
1829
import { logger } from 'genkit/logging';
1930
import {
2031
GenerateRequest,
2132
GenerateResponseData,
2233
GenerationCommonConfigDescriptions,
2334
GenerationCommonConfigSchema,
35+
getBasicUsageStats,
2436
MessageData,
37+
ModelInfo,
38+
modelRef,
2539
ToolDefinition,
26-
getBasicUsageStats,
2740
} from 'genkit/model';
2841
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
42+
import { ActionType } from 'genkit/registry';
2943
import { defineOllamaEmbedder } from './embeddings.js';
3044
import {
3145
ApiType,
46+
ListLocalModelsResponse,
47+
LocalModel,
3248
Message,
3349
ModelDefinition,
3450
OllamaTool,
@@ -39,25 +55,136 @@ import {
3955

4056
export { type OllamaPluginParams };
4157

58+
/**
 * Shape of the exported `ollama` plugin: a callable that builds the plugin,
 * plus helpers that create references to Ollama models and embedders by name.
 */
export type OllamaPlugin = ((params?: OllamaPluginParams) => GenkitPlugin) & {
  /** Creates a model reference for `name`, optionally carrying `config`. */
  model(
    name: string,
    config?: z.infer<typeof OllamaConfigSchema>
  ): ModelReference<typeof OllamaConfigSchema>;

  /** Creates an embedder reference for `name`, optionally carrying `config`. */
  embedder(name: string, config?: Record<string, any>): EmbedderReference;
};
67+
4268
/** JSON schema (draft-07) that declares no constraints of its own. */
const ANY_JSON_SCHEMA: Record<string, any> = {
  $schema: 'http://json-schema.org/draft-07/schema#',
};
4571

46-
export function ollama(params: OllamaPluginParams): GenkitPlugin {
47-
return genkitPlugin('ollama', async (ai: Genkit) => {
48-
const serverAddress = params.serverAddress;
49-
params.models?.map((model) =>
50-
ollamaModel(ai, model, serverAddress, params.requestHeaders)
51-
);
52-
params.embedders?.map((model) =>
53-
defineOllamaEmbedder(ai, {
54-
name: model.name,
55-
modelName: model.name,
56-
dimensions: model.dimensions,
57-
options: params,
58-
})
72+
/**
 * Capability info attached to every model reported by `listActions`.
 * The same generic set is used for all models.
 */
const GENERIC_MODEL_INFO = {
  supports: {
    multiturn: true,
    media: true,
    tools: true,
    toolChoice: true,
    systemRole: true,
    constrained: 'all',
  },
} as ModelInfo;

/** Server address used when the plugin params do not provide one. */
const DEFAULT_OLLAMA_SERVER_ADDRESS = 'http://localhost:11434';
84+
85+
/**
 * Registers every model and embedder that was explicitly listed in the
 * plugin params.
 *
 * @param ai Genkit instance the actions are defined on.
 * @param serverAddress Address of the Ollama server, passed to each model.
 * @param params Plugin params; when omitted nothing is registered.
 */
async function initializer(
  ai: Genkit,
  serverAddress: string,
  params?: OllamaPluginParams
) {
  if (!params) {
    return;
  }
  // Plain loops: the definitions are side effects, no mapped result is needed.
  for (const model of params.models ?? []) {
    defineOllamaModel(ai, model, serverAddress, params.requestHeaders);
  }
  for (const embedder of params.embedders ?? []) {
    defineOllamaEmbedder(ai, {
      name: embedder.name,
      modelName: embedder.name,
      dimensions: embedder.dimensions,
      options: params,
    });
  }
}
102+
103+
/**
 * Defines an Ollama model on demand when an action of type `'model'` is
 * requested by name. Any other action type is ignored.
 */
function resolveAction(
  ai: Genkit,
  actionType: ActionType,
  actionName: string,
  serverAddress: string,
  requestHeaders?: RequestHeaders
) {
  // We can only dynamically resolve models, for embedders user must provide dimensions.
  if (actionType !== 'model') {
    return;
  }
  const definition = { name: actionName };
  defineOllamaModel(ai, definition, serverAddress, requestHeaders);
}
122+
123+
/**
 * Builds model action metadata for the models returned by `listLocalModels`,
 * naming each one `ollama/<model>`.
 *
 * Entries without a `model` value, or whose `model` contains `embed`, are
 * left out.
 */
async function listActions(
  serverAddress: string,
  requestHeaders?: RequestHeaders
): Promise<ActionMetadata[]> {
  const localModels = await listLocalModels(serverAddress, requestHeaders);
  if (!localModels) {
    return [];
  }
  const actions: ActionMetadata[] = [];
  for (const entry of localModels) {
    // naively filter out embedders, unfortunately there's no better way.
    if (!entry.model || entry.model.includes('embed')) {
      continue;
    }
    actions.push(
      modelActionMetadata({
        name: `ollama/${entry.model}`,
        info: GENERIC_MODEL_INFO,
      })
    );
  }
  return actions;
}
140+
141+
/**
 * Creates the Ollama Genkit plugin.
 *
 * The plugin registers the configured models/embedders on initialization,
 * resolves models dynamically by name, and lists the models available on the
 * server.
 *
 * @param params Optional plugin params. When `serverAddress` is missing or
 *   empty, `DEFAULT_OLLAMA_SERVER_ADDRESS` is used.
 * @returns The plugin produced by `genkitPlugin`.
 */
function ollamaPlugin(params?: OllamaPluginParams): GenkitPlugin {
  // `||` (not `??`) so that an empty string also falls back to the default.
  const serverAddress = params?.serverAddress || DEFAULT_OLLAMA_SERVER_ADDRESS;
  // Work on a copy so the object passed in by the caller is never mutated,
  // while the embedder options still see the resolved server address.
  const resolvedParams: OllamaPluginParams = { ...params, serverAddress };
  return genkitPlugin(
    'ollama',
    async (ai: Genkit) => {
      await initializer(ai, serverAddress, resolvedParams);
    },
    async (ai, actionType, actionName) => {
      resolveAction(
        ai,
        actionType,
        actionName,
        serverAddress,
        resolvedParams.requestHeaders
      );
    },
    async () => await listActions(serverAddress, resolvedParams.requestHeaders)
  );
}
166+
167+
/**
 * Lists the models available on the Ollama server via `GET /api/tags`.
 *
 * @param serverAddress Base address of the Ollama server.
 * @param requestHeaders Optional extra headers (static or generated).
 * @returns The `models` array from the response, or an empty array when the
 *   response carries none.
 * @throws Error when the request cannot be sent (the original error is
 *   attached as `cause`) or when the server answers with a non-2xx status.
 */
async function listLocalModels(
  serverAddress: string,
  requestHeaders?: RequestHeaders
): Promise<LocalModel[]> {
  // Resolve headers before the try block so that a failing header provider is
  // not misreported as a connectivity problem.
  const extraHeaders = await getHeaders(serverAddress, requestHeaders);
  // We call the ollama list local models api: https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
  let res;
  try {
    res = await fetch(serverAddress + '/api/tags', {
      method: 'GET',
      headers: {
        'Content-Type': 'application/json',
        ...extraHeaders,
      },
    });
  } catch (e) {
    throw new Error(`Make sure the Ollama server is running.`, {
      cause: e,
    });
  }
  // Surface HTTP failures explicitly instead of failing later while parsing
  // an error body as a model list.
  if (!res.ok) {
    throw new Error(
      `Failed to list Ollama models: ${res.status} ${res.statusText}`
    );
  }
  const modelResponse = (await res.json()) as ListLocalModelsResponse;
  // Honor the declared return type even when `models` is absent.
  return modelResponse?.models ?? [];
}
62189

63190
/**
@@ -92,7 +219,7 @@ export const OllamaConfigSchema = GenerationCommonConfigSchema.extend({
92219
.optional(),
93220
});
94221

95-
function ollamaModel(
222+
function defineOllamaModel(
96223
ai: Genkit,
97224
model: ModelDefinition,
98225
serverAddress: string,
@@ -110,21 +237,20 @@ function ollamaModel(
110237
},
111238
},
112239
async (input, streamingCallback) => {
113-
const options: Record<string, any> = {};
114-
if (input.config?.temperature !== undefined) {
115-
options.temperature = input.config.temperature;
116-
}
117-
if (input.config?.topP !== undefined) {
118-
options.top_p = input.config.topP;
240+
const { topP, topK, stopSequences, maxOutputTokens, ...rest } =
241+
input.config as any;
242+
const options: Record<string, any> = { ...rest };
243+
if (topP !== undefined) {
244+
options.top_p = topP;
119245
}
120-
if (input.config?.topK !== undefined) {
121-
options.top_k = input.config.topK;
246+
if (topK !== undefined) {
247+
options.top_k = topK;
122248
}
123-
if (input.config?.stopSequences !== undefined) {
124-
options.stop = input.config.stopSequences.join('');
249+
if (stopSequences !== undefined) {
250+
options.stop = stopSequences.join('');
125251
}
126-
if (input.config?.maxOutputTokens !== undefined) {
127-
options.num_predict = input.config.maxOutputTokens;
252+
if (maxOutputTokens !== undefined) {
253+
options.num_predict = maxOutputTokens;
128254
}
129255
const type = model.type ?? 'chat';
130256
const request = toOllamaRequest(
@@ -136,18 +262,12 @@ function ollamaModel(
136262
);
137263
logger.debug(request, `ollama request (${type})`);
138264

139-
const extraHeaders = requestHeaders
140-
? typeof requestHeaders === 'function'
141-
? await requestHeaders(
142-
{
143-
serverAddress,
144-
model,
145-
},
146-
input
147-
)
148-
: requestHeaders
149-
: {};
150-
265+
const extraHeaders = await getHeaders(
266+
serverAddress,
267+
requestHeaders,
268+
model,
269+
input
270+
);
151271
let res;
152272
try {
153273
res = await fetch(
@@ -252,6 +372,25 @@ function parseMessage(response: any, type: ApiType): MessageData {
252372
}
253373
}
254374

375+
/**
 * Resolves the extra request headers for a call to the Ollama server.
 *
 * - no `requestHeaders`: an empty object;
 * - a non-function value: that value as-is;
 * - a function: the awaited result of calling it with
 *   `{ serverAddress, model }` and `input`.
 */
async function getHeaders(
  serverAddress: string,
  requestHeaders?: RequestHeaders,
  model?: ModelDefinition,
  input?: GenerateRequest
): Promise<Record<string, string> | void> {
  if (!requestHeaders) {
    return {};
  }
  if (typeof requestHeaders !== 'function') {
    return requestHeaders;
  }
  return await requestHeaders({ serverAddress, model }, input);
}
393+
255394
function toOllamaRequest(
256395
name: string,
257396
input: GenerateRequest,
@@ -278,7 +417,13 @@ function toOllamaRequest(
278417
messageText += c.text;
279418
}
280419
if (c.media) {
281-
images.push(c.media.url);
420+
let imageUri = c.media.url;
421+
// ollama doesn't accept full data URIs, just the base64 encoded image,
422+
// strip out data URI prefix (ex. `data:image/jpeg;base64,`)
423+
if (imageUri.startsWith('data:')) {
424+
imageUri = imageUri.substring(imageUri.indexOf(',') + 1);
425+
}
426+
images.push(imageUri);
282427
}
283428
if (c.toolRequest) {
284429
toolRequests.push(c.toolRequest);
@@ -391,3 +536,24 @@ function isValidOllamaTool(tool: ToolDefinition): boolean {
391536
}
392537
return true;
393538
}
539+
540+
/**
 * The Ollama plugin: call it to create the plugin, or use `ollama.model` /
 * `ollama.embedder` to build references named `ollama/<name>`.
 */
export const ollama = ollamaPlugin as OllamaPlugin;

ollama.model = (
  name: string,
  config?: z.infer<typeof OllamaConfigSchema>
): ModelReference<typeof OllamaConfigSchema> =>
  modelRef({
    name: `ollama/${name}`,
    config,
    configSchema: OllamaConfigSchema,
  });

ollama.embedder = (
  name: string,
  config?: Record<string, any>
): EmbedderReference =>
  embedderRef({
    name: `ollama/${name}`,
    config,
  });

js/plugins/ollama/src/types.ts

+26-3
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ export interface OllamaPluginParams {
8686
embedders?: EmbeddingModelDefinition[];
8787

8888
/**
89-
* The address of the Ollama server.
89+
* The address of the Ollama server. Default: http://localhost:11434
9090
*/
91-
serverAddress: string;
91+
serverAddress?: string;
9292

9393
/**
9494
* Optional request headers, which can be either static or dynamically generated.
@@ -116,7 +116,7 @@ export interface RequestHeaderFunction {
116116
(
117117
params: {
118118
serverAddress: string;
119-
model: ModelDefinition | EmbeddingModelDefinition;
119+
model?: ModelDefinition | EmbeddingModelDefinition;
120120
modelRequest?: GenerateRequest;
121121
embedRequest?: EmbedRequest;
122122
},
@@ -157,3 +157,26 @@ export interface Message {
157157
images?: string[];
158158
tool_calls?: any[];
159159
}
160+
161+
/** Ollama local model definition. */
export interface LocalModel {
  name: string;
  model: string;
  /** ISO 8601 format date. */
  modified_at: string;
  size: number;
  digest: string;
  /** Optional descriptive details about the model. */
  details?: {
    parent_model?: string;
    format?: string;
    family?: string;
    families?: Array<string>;
    parameter_size?: string;
    quantization_level?: string;
  };
}
178+
179+
/** Ollama list local models response. */
export interface ListLocalModelsResponse {
  models: Array<LocalModel>;
}

0 commit comments

Comments
 (0)