14
14
* limitations under the License.
15
15
*/
16
16
17
- import { Genkit , ToolRequest , ToolRequestPart , ToolResponse , z } from 'genkit' ;
17
+ import {
18
+ ActionMetadata ,
19
+ Genkit ,
20
+ ToolRequest ,
21
+ ToolRequestPart ,
22
+ ToolResponse ,
23
+ z ,
24
+ } from 'genkit' ;
18
25
import { logger } from 'genkit/logging' ;
19
26
import {
20
27
GenerateRequest ,
21
28
GenerateResponseData ,
22
29
GenerationCommonConfigDescriptions ,
23
30
GenerationCommonConfigSchema ,
24
31
MessageData ,
32
+ ModelInfo ,
25
33
ToolDefinition ,
26
34
getBasicUsageStats ,
27
35
} from 'genkit/model' ;
@@ -43,21 +51,83 @@ const ANY_JSON_SCHEMA: Record<string, any> = {
43
51
$schema : 'http://json-schema.org/draft-07/schema#' ,
44
52
} ;
45
53
54
+ const GENERIC_MODEL_INFO = {
55
+ supports : {
56
+ multiturn : true ,
57
+ media : true ,
58
+ tools : true ,
59
+ toolChoice : true ,
60
+ systemRole : true ,
61
+ constrained : 'all' ,
62
+ } ,
63
+ } as ModelInfo ;
64
+
46
65
export function ollama ( params : OllamaPluginParams ) : GenkitPlugin {
47
- return genkitPlugin ( 'ollama' , async ( ai : Genkit ) => {
48
- const serverAddress = params . serverAddress ;
49
- params . models ?. map ( ( model ) =>
50
- ollamaModel ( ai , model , serverAddress , params . requestHeaders )
51
- ) ;
52
- params . embedders ?. map ( ( model ) =>
53
- defineOllamaEmbedder ( ai , {
54
- name : model . name ,
55
- modelName : model . name ,
56
- dimensions : model . dimensions ,
57
- options : params ,
58
- } )
59
- ) ;
60
- } ) ;
66
+ const serverAddress = params . serverAddress ;
67
+ return genkitPlugin (
68
+ 'ollama' ,
69
+ async ( ai : Genkit ) => {
70
+ params . models ?. map ( ( model ) =>
71
+ ollamaModel ( ai , model , serverAddress , params . requestHeaders )
72
+ ) ;
73
+ params . embedders ?. map ( ( model ) =>
74
+ defineOllamaEmbedder ( ai , {
75
+ name : model . name ,
76
+ modelName : model . name ,
77
+ dimensions : model . dimensions ,
78
+ options : params ,
79
+ } )
80
+ ) ;
81
+ } ,
82
+ async ( ai , actionType , actionName ) => {
83
+ // We can only dynamically resolve models, for embedders user must provide dimensions.
84
+ if ( actionType === 'model' ) {
85
+ ollamaModel (
86
+ ai ,
87
+ {
88
+ name : actionName ,
89
+ } ,
90
+ serverAddress ,
91
+ params . requestHeaders
92
+ ) ;
93
+ }
94
+ } ,
95
+ async ( ) => {
96
+ // We call the ollama list local models api: https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
97
+ let res ;
98
+ try {
99
+ res = await fetch ( serverAddress + '/api/tags' , {
100
+ method : 'GET' ,
101
+ headers : {
102
+ 'Content-Type' : 'application/json' ,
103
+ ...( await getHeaders ( serverAddress , params . requestHeaders ) ) ,
104
+ } ,
105
+ } ) ;
106
+ } catch ( e ) {
107
+ throw new Error ( `Make sure the Ollama server is running.` , {
108
+ cause : e ,
109
+ } ) ;
110
+ }
111
+ const modelResponse = JSON . parse ( await res . text ( ) ) ;
112
+ return (
113
+ modelResponse ?. models
114
+ // naively filter out embedders, unfortunately there's no better way.
115
+ ?. filter ( ( m ) => m . model && ! m . model . includes ( 'embed' ) )
116
+ . map (
117
+ ( m ) =>
118
+ ( {
119
+ actionType : 'model' ,
120
+ name : `ollama/${ m . model } ` ,
121
+ metadata : {
122
+ model : {
123
+ ...GENERIC_MODEL_INFO ,
124
+ } as ModelInfo ,
125
+ } ,
126
+ } ) as ActionMetadata
127
+ ) || [ ]
128
+ ) ;
129
+ }
130
+ ) ;
61
131
}
62
132
63
133
/**
@@ -110,21 +180,29 @@ function ollamaModel(
110
180
} ,
111
181
} ,
112
182
async ( input , streamingCallback ) => {
113
- const options : Record < string , any > = { } ;
114
- if ( input . config ?. temperature !== undefined ) {
115
- options . temperature = input . config . temperature ;
183
+ const {
184
+ temperature,
185
+ topP,
186
+ topK,
187
+ stopSequences,
188
+ maxOutputTokens,
189
+ ...rest
190
+ } = input . config as any ;
191
+ const options : Record < string , any > = { ...rest } ;
192
+ if ( temperature !== undefined ) {
193
+ options . temperature = temperature ;
116
194
}
117
- if ( input . config ?. topP !== undefined ) {
118
- options . top_p = input . config . topP ;
195
+ if ( topP !== undefined ) {
196
+ options . top_p = topP ;
119
197
}
120
- if ( input . config ?. topK !== undefined ) {
121
- options . top_k = input . config . topK ;
198
+ if ( topK !== undefined ) {
199
+ options . top_k = topK ;
122
200
}
123
- if ( input . config ?. stopSequences !== undefined ) {
124
- options . stop = input . config . stopSequences . join ( '' ) ;
201
+ if ( stopSequences !== undefined ) {
202
+ options . stop = stopSequences . join ( '' ) ;
125
203
}
126
- if ( input . config ?. maxOutputTokens !== undefined ) {
127
- options . num_predict = input . config . maxOutputTokens ;
204
+ if ( maxOutputTokens !== undefined ) {
205
+ options . num_predict = maxOutputTokens ;
128
206
}
129
207
const type = model . type ?? 'chat' ;
130
208
const request = toOllamaRequest (
@@ -136,18 +214,12 @@ function ollamaModel(
136
214
) ;
137
215
logger . debug ( request , `ollama request (${ type } )` ) ;
138
216
139
- const extraHeaders = requestHeaders
140
- ? typeof requestHeaders === 'function'
141
- ? await requestHeaders (
142
- {
143
- serverAddress,
144
- model,
145
- } ,
146
- input
147
- )
148
- : requestHeaders
149
- : { } ;
150
-
217
+ const extraHeaders = await getHeaders (
218
+ serverAddress ,
219
+ requestHeaders ,
220
+ model ,
221
+ input
222
+ ) ;
151
223
let res ;
152
224
try {
153
225
res = await fetch (
@@ -252,6 +324,25 @@ function parseMessage(response: any, type: ApiType): MessageData {
252
324
}
253
325
}
254
326
327
+ async function getHeaders (
328
+ serverAddress : string ,
329
+ requestHeaders ?: RequestHeaders ,
330
+ model ?: ModelDefinition ,
331
+ input ?: GenerateRequest
332
+ ) : Promise < Record < string , string > | void > {
333
+ return requestHeaders
334
+ ? typeof requestHeaders === 'function'
335
+ ? await requestHeaders (
336
+ {
337
+ serverAddress,
338
+ model,
339
+ } ,
340
+ input
341
+ )
342
+ : requestHeaders
343
+ : { } ;
344
+ }
345
+
255
346
function toOllamaRequest (
256
347
name : string ,
257
348
input : GenerateRequest ,
@@ -278,7 +369,13 @@ function toOllamaRequest(
278
369
messageText += c . text ;
279
370
}
280
371
if ( c . media ) {
281
- images . push ( c . media . url ) ;
372
+ let imageUri = c . media . url ;
373
+ // ollama doesn't accept full data URIs, just the base64 encoded image,
374
+ // strip out data URI prefix (ex. `data:image/jpeg;base64,`)
375
+ if ( imageUri . startsWith ( 'data:' ) ) {
376
+ imageUri = imageUri . substring ( imageUri . indexOf ( ',' ) + 1 ) ;
377
+ }
378
+ images . push ( imageUri ) ;
282
379
}
283
380
if ( c . toolRequest ) {
284
381
toolRequests . push ( c . toolRequest ) ;
0 commit comments