
Commit 2343d84

feat(js/plugins/ollama): implemented dynamic model resolver and list actions for ollama plugin
1 parent f060923 commit 2343d84

2 files changed: +137 −40 lines

js/plugins/ollama/src/index.ts

+136 −39
@@ -14,14 +14,22 @@
  * limitations under the License.
  */
 
-import { Genkit, ToolRequest, ToolRequestPart, ToolResponse, z } from 'genkit';
+import {
+  ActionMetadata,
+  Genkit,
+  ToolRequest,
+  ToolRequestPart,
+  ToolResponse,
+  z,
+} from 'genkit';
 import { logger } from 'genkit/logging';
 import {
   GenerateRequest,
   GenerateResponseData,
   GenerationCommonConfigDescriptions,
   GenerationCommonConfigSchema,
   MessageData,
+  ModelInfo,
   ToolDefinition,
   getBasicUsageStats,
 } from 'genkit/model';
@@ -43,21 +51,83 @@ const ANY_JSON_SCHEMA: Record<string, any> = {
   $schema: 'http://json-schema.org/draft-07/schema#',
 };
 
+const GENERIC_MODEL_INFO = {
+  supports: {
+    multiturn: true,
+    media: true,
+    tools: true,
+    toolChoice: true,
+    systemRole: true,
+    constrained: 'all',
+  },
+} as ModelInfo;
+
 export function ollama(params: OllamaPluginParams): GenkitPlugin {
-  return genkitPlugin('ollama', async (ai: Genkit) => {
-    const serverAddress = params.serverAddress;
-    params.models?.map((model) =>
-      ollamaModel(ai, model, serverAddress, params.requestHeaders)
-    );
-    params.embedders?.map((model) =>
-      defineOllamaEmbedder(ai, {
-        name: model.name,
-        modelName: model.name,
-        dimensions: model.dimensions,
-        options: params,
-      })
-    );
-  });
+  const serverAddress = params.serverAddress;
+  return genkitPlugin(
+    'ollama',
+    async (ai: Genkit) => {
+      params.models?.map((model) =>
+        ollamaModel(ai, model, serverAddress, params.requestHeaders)
+      );
+      params.embedders?.map((model) =>
+        defineOllamaEmbedder(ai, {
+          name: model.name,
+          modelName: model.name,
+          dimensions: model.dimensions,
+          options: params,
+        })
+      );
+    },
+    async (ai, actionType, actionName) => {
+      // We can only dynamically resolve models, for embedders user must provide dimensions.
+      if (actionType === 'model') {
+        ollamaModel(
+          ai,
+          {
+            name: actionName,
+          },
+          serverAddress,
+          params.requestHeaders
+        );
+      }
+    },
+    async () => {
+      // We call the ollama list local models api: https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
+      let res;
+      try {
+        res = await fetch(serverAddress + '/api/tags', {
+          method: 'GET',
+          headers: {
+            'Content-Type': 'application/json',
+            ...(await getHeaders(serverAddress, params.requestHeaders)),
+          },
+        });
+      } catch (e) {
+        throw new Error(`Make sure the Ollama server is running.`, {
+          cause: e,
+        });
+      }
+      const modelResponse = JSON.parse(await res.text());
+      return (
+        modelResponse?.models
+          // naively filter out embedders, unfortunately there's no better way.
+          ?.filter((m) => m.model && !m.model.includes('embed'))
+          .map(
+            (m) =>
+              ({
+                actionType: 'model',
+                name: `ollama/${m.model}`,
+                metadata: {
+                  model: {
+                    ...GENERIC_MODEL_INFO,
+                  } as ModelInfo,
+                },
+              }) as ActionMetadata
+          ) || []
+      );
+    }
+  );
 }
 
 /**
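The second callback above is the dynamic model resolver: any `ollama/<name>` model referenced at runtime gets defined on the fly, and the third callback enumerates locally pulled models via `/api/tags` so tooling can list them. A minimal usage sketch of the resolver follows; the package name `genkitx-ollama`, the server address, and the `gemma2` model name are illustrative assumptions, not taken from this commit.

import { genkit } from 'genkit';
import { ollama } from 'genkitx-ollama'; // assumed published package name for this plugin

const ai = genkit({
  // No `models` array: nothing is registered up front.
  plugins: [ollama({ serverAddress: 'http://127.0.0.1:11434' })],
});

(async () => {
  // 'ollama/gemma2' was never configured; the resolver callback defines it on
  // first reference instead of failing with an unknown-model error.
  const { text } = await ai.generate({
    model: 'ollama/gemma2',
    prompt: 'Why is the sky blue?',
  });
  console.log(text);
})();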
@@ -110,21 +180,29 @@ function ollamaModel(
       },
     },
     async (input, streamingCallback) => {
-      const options: Record<string, any> = {};
-      if (input.config?.temperature !== undefined) {
-        options.temperature = input.config.temperature;
+      const {
+        temperature,
+        topP,
+        topK,
+        stopSequences,
+        maxOutputTokens,
+        ...rest
+      } = input.config as any;
+      const options: Record<string, any> = { ...rest };
+      if (temperature !== undefined) {
+        options.temperature = temperature;
       }
-      if (input.config?.topP !== undefined) {
-        options.top_p = input.config.topP;
+      if (topP !== undefined) {
+        options.top_p = topP;
       }
-      if (input.config?.topK !== undefined) {
-        options.top_k = input.config.topK;
+      if (topK !== undefined) {
+        options.top_k = topK;
       }
-      if (input.config?.stopSequences !== undefined) {
-        options.stop = input.config.stopSequences.join('');
+      if (stopSequences !== undefined) {
+        options.stop = stopSequences.join('');
       }
-      if (input.config?.maxOutputTokens !== undefined) {
-        options.num_predict = input.config.maxOutputTokens;
+      if (maxOutputTokens !== undefined) {
+        options.num_predict = maxOutputTokens;
       }
       const type = model.type ?? 'chat';
       const request = toOllamaRequest(
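For reference, the destructuring above renames the shared Genkit config keys to Ollama's native option names and, new in this change, passes any other config keys straight through via `...rest`. A standalone sketch of that mapping with a made-up config; the `mirostat` key is only an example of an Ollama-specific passthrough option, not something this commit adds.

// Illustration only; mirrors the mapping in the handler above.
const config = { temperature: 0.7, topP: 0.9, maxOutputTokens: 256, mirostat: 1 };
const { temperature, topP, topK, stopSequences, maxOutputTokens, ...rest } =
  config as any;
const options: Record<string, any> = { ...rest };
if (temperature !== undefined) options.temperature = temperature;
if (topP !== undefined) options.top_p = topP;
if (topK !== undefined) options.top_k = topK;
if (stopSequences !== undefined) options.stop = stopSequences.join('');
if (maxOutputTokens !== undefined) options.num_predict = maxOutputTokens;
console.log(options); // { mirostat: 1, temperature: 0.7, top_p: 0.9, num_predict: 256 }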
@@ -136,18 +214,12 @@ function ollamaModel(
       );
       logger.debug(request, `ollama request (${type})`);
 
-      const extraHeaders = requestHeaders
-        ? typeof requestHeaders === 'function'
-          ? await requestHeaders(
-              {
-                serverAddress,
-                model,
-              },
-              input
-            )
-          : requestHeaders
-        : {};
-
+      const extraHeaders = await getHeaders(
+        serverAddress,
+        requestHeaders,
+        model,
+        input
+      );
       let res;
       try {
         res = await fetch(
@@ -252,6 +324,25 @@ function parseMessage(response: any, type: ApiType): MessageData {
   }
 }
 
+async function getHeaders(
+  serverAddress: string,
+  requestHeaders?: RequestHeaders,
+  model?: ModelDefinition,
+  input?: GenerateRequest
+): Promise<Record<string, string> | void> {
+  return requestHeaders
+    ? typeof requestHeaders === 'function'
+      ? await requestHeaders(
+          {
+            serverAddress,
+            model,
+          },
+          input
+        )
+      : requestHeaders
+    : {};
+}
+
 function toOllamaRequest(
   name: string,
   input: GenerateRequest,
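Because `getHeaders` accepts both forms of `RequestHeaders`, plugin configuration can supply either a plain header map or an async function. A hedged sketch of both shapes; the package name, token, and header names are placeholders, not from this commit.

import { ollama } from 'genkitx-ollama'; // assumed package name

// Static form: getHeaders() returns the map unchanged.
const withStaticHeaders = ollama({
  serverAddress: 'http://127.0.0.1:11434',
  requestHeaders: { Authorization: 'Bearer <token>' },
});

// Function form: getHeaders() awaits it. `model` may now be undefined because
// the list-models action requests headers without a specific model.
const withHeaderFunction = ollama({
  serverAddress: 'http://127.0.0.1:11434',
  requestHeaders: async ({ serverAddress, model }) => ({
    Authorization: 'Bearer <token>',
    'x-target-model': model?.name ?? 'none',
  }),
});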
@@ -278,7 +369,13 @@ function toOllamaRequest(
         messageText += c.text;
       }
       if (c.media) {
-        images.push(c.media.url);
+        let imageUri = c.media.url;
+        // ollama doesn't accept full data URIs, just the base64 encoded image,
+        // strip out data URI prefix (ex. `data:image/jpeg;base64,`)
+        if (imageUri.startsWith('data:')) {
+          imageUri = imageUri.substring(imageUri.indexOf(',') + 1);
+        }
+        images.push(imageUri);
       }
       if (c.toolRequest) {
         toolRequests.push(c.toolRequest);
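The block above only strips the `data:<mime>;base64,` prefix; plain URLs or already-bare base64 strings pass through untouched. A quick standalone check with a made-up, truncated payload:

// Same transformation as above, pulled out for illustration.
const url = 'data:image/jpeg;base64,/9j/4AAQSkZJRgABAQ==';
let imageUri = url;
if (imageUri.startsWith('data:')) {
  imageUri = imageUri.substring(imageUri.indexOf(',') + 1);
}
console.log(imageUri); // '/9j/4AAQSkZJRgABAQ==' -- only the base64 body is sent to Ollama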

js/plugins/ollama/src/types.ts

+1 −1
@@ -116,7 +116,7 @@ export interface RequestHeaderFunction {
   (
     params: {
       serverAddress: string;
-      model: ModelDefinition | EmbeddingModelDefinition;
+      model?: ModelDefinition | EmbeddingModelDefinition;
       modelRequest?: GenerateRequest;
       embedRequest?: EmbedRequest;
     },
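With `model` now optional, existing `RequestHeaderFunction` implementations should tolerate its absence, since the new list action asks for headers without a model. A minimal sketch, assuming the type is imported from this plugin's `types.ts` and that the header name shown is purely illustrative:

import type { RequestHeaderFunction } from './types'; // path within this plugin

const requestHeaders: RequestHeaderFunction = async ({ serverAddress, model }) => ({
  // `model` is undefined when headers are built for the /api/tags list call.
  'x-requested-for': model?.name ?? serverAddress,
});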
