14
14
* limitations under the License.
15
15
*/
16
16
17
- import { Genkit , ToolRequest , ToolRequestPart , ToolResponse , z } from 'genkit' ;
17
+ import {
18
+ ActionMetadata ,
19
+ embedderRef ,
20
+ Genkit ,
21
+ modelActionMetadata ,
22
+ ToolRequest ,
23
+ ToolRequestPart ,
24
+ ToolResponse ,
25
+ z ,
26
+ type EmbedderReference ,
27
+ type ModelReference ,
28
+ } from 'genkit' ;
18
29
import { logger } from 'genkit/logging' ;
19
30
import {
20
31
GenerateRequest ,
21
32
GenerateResponseData ,
22
33
GenerationCommonConfigDescriptions ,
23
34
GenerationCommonConfigSchema ,
35
+ getBasicUsageStats ,
24
36
MessageData ,
37
+ ModelInfo ,
38
+ modelRef ,
25
39
ToolDefinition ,
26
- getBasicUsageStats ,
27
40
} from 'genkit/model' ;
28
41
import { GenkitPlugin , genkitPlugin } from 'genkit/plugin' ;
42
+ import { ActionType } from 'genkit/registry' ;
29
43
import { defineOllamaEmbedder } from './embeddings.js' ;
30
44
import {
31
45
ApiType ,
46
+ ListLocalModelsResponse ,
47
+ LocalModel ,
32
48
Message ,
33
49
ModelDefinition ,
34
50
OllamaTool ,
@@ -39,25 +55,136 @@ import {
39
55
40
56
export { type OllamaPluginParams } ;
41
57
58
/**
 * The Ollama plugin's public surface: a callable plugin factory plus
 * helpers for building references to Ollama models and embedders by name.
 */
export type OllamaPlugin = {
  /** Creates the Genkit plugin instance. */
  (params?: OllamaPluginParams): GenkitPlugin;

  /** Builds a `ModelReference` for the named Ollama model (name is prefixed with `ollama/`). */
  model(
    name: string,
    config?: z.infer<typeof OllamaConfigSchema>
  ): ModelReference<typeof OllamaConfigSchema>;

  /** Builds an `EmbedderReference` for the named Ollama embedder (name is prefixed with `ollama/`). */
  embedder(name: string, config?: Record<string, any>): EmbedderReference;
};
67
+
42
68
// A maximally-permissive JSON schema: declaring only `$schema` places no
// constraints on the value, so any JSON validates against it.
// NOTE(review): presumably used as a fallback where a schema is required but
// none was provided — confirm at the call sites (not visible in this chunk).
const ANY_JSON_SCHEMA: Record<string, any> = {
  $schema: 'http://json-schema.org/draft-07/schema#',
};
45
71
46
- export function ollama ( params : OllamaPluginParams ) : GenkitPlugin {
47
- return genkitPlugin ( 'ollama' , async ( ai : Genkit ) => {
48
- const serverAddress = params . serverAddress ;
49
- params . models ?. map ( ( model ) =>
50
- ollamaModel ( ai , model , serverAddress , params . requestHeaders )
51
- ) ;
52
- params . embedders ?. map ( ( model ) =>
53
- defineOllamaEmbedder ( ai , {
54
- name : model . name ,
55
- modelName : model . name ,
56
- dimensions : model . dimensions ,
57
- options : params ,
58
- } )
72
+ const GENERIC_MODEL_INFO = {
73
+ supports : {
74
+ multiturn : true ,
75
+ media : true ,
76
+ tools : true ,
77
+ toolChoice : true ,
78
+ systemRole : true ,
79
+ constrained : 'all' ,
80
+ } ,
81
+ } as ModelInfo ;
82
+
83
// Address used when the caller does not supply `serverAddress`; this is the
// default port of a locally running Ollama server.
const DEFAULT_OLLAMA_SERVER_ADDRESS = 'http://localhost:11434';
84
+
85
+ async function initializer (
86
+ ai : Genkit ,
87
+ serverAddress : string ,
88
+ params ?: OllamaPluginParams
89
+ ) {
90
+ params ?. models ?. map ( ( model ) =>
91
+ defineOllamaModel ( ai , model , serverAddress , params ?. requestHeaders )
92
+ ) ;
93
+ params ?. embedders ?. map ( ( model ) =>
94
+ defineOllamaEmbedder ( ai , {
95
+ name : model . name ,
96
+ modelName : model . name ,
97
+ dimensions : model . dimensions ,
98
+ options : params ! ,
99
+ } )
100
+ ) ;
101
+ }
102
+
103
+ function resolveAction (
104
+ ai : Genkit ,
105
+ actionType : ActionType ,
106
+ actionName : string ,
107
+ serverAddress : string ,
108
+ requestHeaders ?: RequestHeaders
109
+ ) {
110
+ // We can only dynamically resolve models, for embedders user must provide dimensions.
111
+ if ( actionType === 'model' ) {
112
+ defineOllamaModel (
113
+ ai ,
114
+ {
115
+ name : actionName ,
116
+ } ,
117
+ serverAddress ,
118
+ requestHeaders
59
119
) ;
60
- } ) ;
120
+ }
121
+ }
122
+
123
+ async function listActions (
124
+ serverAddress : string ,
125
+ requestHeaders ?: RequestHeaders
126
+ ) : Promise < ActionMetadata [ ] > {
127
+ const models = await listLocalModels ( serverAddress , requestHeaders ) ;
128
+ return (
129
+ models
130
+ // naively filter out embedders, unfortunately there's no better way.
131
+ ?. filter ( ( m ) => m . model && ! m . model . includes ( 'embed' ) )
132
+ . map ( ( m ) =>
133
+ modelActionMetadata ( {
134
+ name : `ollama/${ m . model } ` ,
135
+ info : GENERIC_MODEL_INFO ,
136
+ } )
137
+ ) || [ ]
138
+ ) ;
139
+ }
140
+
141
+ function ollamaPlugin ( params ?: OllamaPluginParams ) : GenkitPlugin {
142
+ if ( ! params ) {
143
+ params = { } ;
144
+ }
145
+ if ( ! params . serverAddress ) {
146
+ params . serverAddress = DEFAULT_OLLAMA_SERVER_ADDRESS ;
147
+ }
148
+ const serverAddress = params . serverAddress ;
149
+ return genkitPlugin (
150
+ 'ollama' ,
151
+ async ( ai : Genkit ) => {
152
+ await initializer ( ai , serverAddress , params ) ;
153
+ } ,
154
+ async ( ai , actionType , actionName ) => {
155
+ resolveAction (
156
+ ai ,
157
+ actionType ,
158
+ actionName ,
159
+ serverAddress ,
160
+ params ?. requestHeaders
161
+ ) ;
162
+ } ,
163
+ async ( ) => await listActions ( serverAddress , params ?. requestHeaders )
164
+ ) ;
165
+ }
166
+
167
+ async function listLocalModels (
168
+ serverAddress : string ,
169
+ requestHeaders ?: RequestHeaders
170
+ ) : Promise < LocalModel [ ] > {
171
+ // We call the ollama list local models api: https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
172
+ let res ;
173
+ try {
174
+ res = await fetch ( serverAddress + '/api/tags' , {
175
+ method : 'GET' ,
176
+ headers : {
177
+ 'Content-Type' : 'application/json' ,
178
+ ...( await getHeaders ( serverAddress , requestHeaders ) ) ,
179
+ } ,
180
+ } ) ;
181
+ } catch ( e ) {
182
+ throw new Error ( `Make sure the Ollama server is running.` , {
183
+ cause : e ,
184
+ } ) ;
185
+ }
186
+ const modelResponse = JSON . parse ( await res . text ( ) ) as ListLocalModelsResponse ;
187
+ return modelResponse . models ;
61
188
}
62
189
63
190
/**
@@ -92,7 +219,7 @@ export const OllamaConfigSchema = GenerationCommonConfigSchema.extend({
92
219
. optional ( ) ,
93
220
} ) ;
94
221
95
- function ollamaModel (
222
+ function defineOllamaModel (
96
223
ai : Genkit ,
97
224
model : ModelDefinition ,
98
225
serverAddress : string ,
@@ -110,21 +237,20 @@ function ollamaModel(
110
237
} ,
111
238
} ,
112
239
async ( input , streamingCallback ) => {
113
- const options : Record < string , any > = { } ;
114
- if ( input . config ?. temperature !== undefined ) {
115
- options . temperature = input . config . temperature ;
116
- }
117
- if ( input . config ?. topP !== undefined ) {
118
- options . top_p = input . config . topP ;
240
+ const { topP, topK, stopSequences, maxOutputTokens, ...rest } =
241
+ input . config as any ;
242
+ const options : Record < string , any > = { ...rest } ;
243
+ if ( topP !== undefined ) {
244
+ options . top_p = topP ;
119
245
}
120
- if ( input . config ?. topK !== undefined ) {
121
- options . top_k = input . config . topK ;
246
+ if ( topK !== undefined ) {
247
+ options . top_k = topK ;
122
248
}
123
- if ( input . config ?. stopSequences !== undefined ) {
124
- options . stop = input . config . stopSequences . join ( '' ) ;
249
+ if ( stopSequences !== undefined ) {
250
+ options . stop = stopSequences . join ( '' ) ;
125
251
}
126
- if ( input . config ?. maxOutputTokens !== undefined ) {
127
- options . num_predict = input . config . maxOutputTokens ;
252
+ if ( maxOutputTokens !== undefined ) {
253
+ options . num_predict = maxOutputTokens ;
128
254
}
129
255
const type = model . type ?? 'chat' ;
130
256
const request = toOllamaRequest (
@@ -136,18 +262,12 @@ function ollamaModel(
136
262
) ;
137
263
logger . debug ( request , `ollama request (${ type } )` ) ;
138
264
139
- const extraHeaders = requestHeaders
140
- ? typeof requestHeaders === 'function'
141
- ? await requestHeaders (
142
- {
143
- serverAddress,
144
- model,
145
- } ,
146
- input
147
- )
148
- : requestHeaders
149
- : { } ;
150
-
265
+ const extraHeaders = await getHeaders (
266
+ serverAddress ,
267
+ requestHeaders ,
268
+ model ,
269
+ input
270
+ ) ;
151
271
let res ;
152
272
try {
153
273
res = await fetch (
@@ -252,6 +372,25 @@ function parseMessage(response: any, type: ApiType): MessageData {
252
372
}
253
373
}
254
374
375
+ async function getHeaders (
376
+ serverAddress : string ,
377
+ requestHeaders ?: RequestHeaders ,
378
+ model ?: ModelDefinition ,
379
+ input ?: GenerateRequest
380
+ ) : Promise < Record < string , string > | void > {
381
+ return requestHeaders
382
+ ? typeof requestHeaders === 'function'
383
+ ? await requestHeaders (
384
+ {
385
+ serverAddress,
386
+ model,
387
+ } ,
388
+ input
389
+ )
390
+ : requestHeaders
391
+ : { } ;
392
+ }
393
+
255
394
function toOllamaRequest (
256
395
name : string ,
257
396
input : GenerateRequest ,
@@ -278,7 +417,13 @@ function toOllamaRequest(
278
417
messageText += c . text ;
279
418
}
280
419
if ( c . media ) {
281
- images . push ( c . media . url ) ;
420
+ let imageUri = c . media . url ;
421
+ // ollama doesn't accept full data URIs, just the base64 encoded image,
422
+ // strip out data URI prefix (ex. `data:image/jpeg;base64,`)
423
+ if ( imageUri . startsWith ( 'data:' ) ) {
424
+ imageUri = imageUri . substring ( imageUri . indexOf ( ',' ) + 1 ) ;
425
+ }
426
+ images . push ( imageUri ) ;
282
427
}
283
428
if ( c . toolRequest ) {
284
429
toolRequests . push ( c . toolRequest ) ;
@@ -391,3 +536,24 @@ function isValidOllamaTool(tool: ToolDefinition): boolean {
391
536
}
392
537
return true ;
393
538
}
539
+
540
+ export const ollama = ollamaPlugin as OllamaPlugin ;
541
+ ollama . model = (
542
+ name : string ,
543
+ config ?: z . infer < typeof OllamaConfigSchema >
544
+ ) : ModelReference < typeof OllamaConfigSchema > => {
545
+ return modelRef ( {
546
+ name : `ollama/${ name } ` ,
547
+ config,
548
+ configSchema : OllamaConfigSchema ,
549
+ } ) ;
550
+ } ;
551
+ ollama . embedder = (
552
+ name : string ,
553
+ config ?: Record < string , any >
554
+ ) : EmbedderReference => {
555
+ return embedderRef ( {
556
+ name : `ollama/${ name } ` ,
557
+ config,
558
+ } ) ;
559
+ } ;
0 commit comments