1
+ # Copyright 2025 Google LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pdb
5
+ from enum import StrEnum
6
+ from genkit .core .typing import (
7
+ ModelInfo ,
8
+ Supports ,
9
+ GenerationCommonConfig ,
10
+ GenerateRequest ,
11
+ GenerateResponse ,
12
+ Message ,
13
+ Role ,
14
+ TextPart ,
15
+ GenerateResponseChunk
16
+ )
17
+ from genkit .core .action import ActionRunContext
18
+ from openai import OpenAI as OpenAIClient
19
+ from google .auth import default , transport
20
+ from typing import Annotated
21
+
22
+ from pydantic import BaseModel , ConfigDict
23
+
24
+ class ChatMessage (BaseModel ):
25
+ model_config = ConfigDict (extra = 'forbid' , populate_by_name = True )
26
+
27
+ role : str
28
+ content : str
29
+
30
+
31
+ class OpenAIConfig (GenerationCommonConfig ):
32
+ """Config for OpenAI model."""
33
+ frequency_penalty : Annotated [float , range (- 2 , 2 )] | None = None
34
+ logit_bias : dict [str , Annotated [float , range (- 100 , 100 )]] | None = None
35
+ logprobs : bool | None = None
36
+ presence_penalty : Annotated [float , range (- 2 , 2 )] | None = None
37
+ seed : int | None = None
38
+ top_logprobs : Annotated [int , range (0 , 20 )] | None = None
39
+ user : str | None = None
40
+
41
+
42
+ class ChatCompletionRole (StrEnum ):
43
+ """Available roles supported by openai-compatible models."""
44
+ USER = 'user'
45
+ ASSISTANT = 'assistant'
46
+ SYSTEM = 'system'
47
+ TOOL = 'tool'
48
+
49
+
50
+ class OpenAICompatibleModel :
51
+ "Handles openai compatible model support in model_garden" ""
52
+
53
+ def __init__ (self , model : str , project_id : str , location : str ):
54
+ self ._model = model
55
+ self ._client = self .client_factory (location , project_id )
56
+
57
+ def client_factory (self , location : str , project_id : str ) -> OpenAIClient :
58
+ """Initiates an openai compatible client object and return it."""
59
+ if project_id :
60
+ credentials , _ = default ()
61
+ else :
62
+ credentials , project_id = default ()
63
+
64
+ credentials .refresh (transport .requests .Request ())
65
+ base_url = f'https://{ location } -aiplatform.googleapis.com/v1beta1/projects/{ project_id } /locations/{ location } /endpoints/openapi'
66
+ return OpenAIClient (api_key = credentials .token , base_url = base_url )
67
+
68
+
69
+ def to_openai_messages (self , messages : list [Message ]) -> list [ChatMessage ]:
70
+ if not messages :
71
+ raise ValueError ('No messages provided in the request.' )
72
+ return [
73
+ ChatMessage (
74
+ role = OpenAICompatibleModel .to_openai_role (m .role .value ),
75
+ content = '' .join (
76
+ part .root .text
77
+ for part in m .content
78
+ if part .root .text is not None
79
+ ),
80
+ )
81
+ for m in messages
82
+ ]
83
+ def generate (
84
+ self , request : GenerateRequest , ctx : ActionRunContext
85
+ ) -> GenerateResponse :
86
+ openai_config : dict = {
87
+ 'messages' : self .to_openai_messages (request .messages ),
88
+ 'model' : self ._model
89
+ }
90
+ if ctx .is_streaming :
91
+ openai_config ['stream' ] = True
92
+ stream = self ._client .chat .completions .create (** openai_config )
93
+ for chunk in stream :
94
+ choice = chunk .choices [0 ]
95
+ if not choice .delta .content :
96
+ continue
97
+
98
+ response_chunk = GenerateResponseChunk (
99
+ role = Role .MODEL ,
100
+ index = choice .index ,
101
+ content = [TextPart (text = choice .delta .content )],
102
+ )
103
+
104
+ ctx .send_chunk (response_chunk )
105
+
106
+ else :
107
+ response = self ._client .chat .completions .create (** openai_config )
108
+ return GenerateResponse (
109
+ request = request ,
110
+ message = Message (
111
+ role = Role .MODEL ,
112
+ content = [TextPart (text = response .choices [0 ].message .content )],
113
+ ),
114
+ )
115
+
116
+ @staticmethod
117
+ def to_openai_role (role : Role ) -> ChatCompletionRole :
118
+ """Converts Role enum to corrosponding OpenAI Compatible role."""
119
+ match role :
120
+ case Role .USER :
121
+ return ChatCompletionRole .USER
122
+ case Role .MODEL :
123
+ return ChatCompletionRole .ASSISTANT # "model" maps to "assistant"
124
+ case Role .SYSTEM :
125
+ return ChatCompletionRole .SYSTEM
126
+ case Role .TOOL :
127
+ return ChatCompletionRole .TOOL
128
+ case _:
129
+ raise ValueError (f"Role '{ role } ' doesn't map to an OpenAI role." )
130
+
131
+
132
+
133
+ class OllamaVersion (StrEnum ):
134
+ """Available versions of the llama model.
135
+
136
+ This enum defines the available versions of the llama model that
137
+ can be used through Vertex AI.
138
+ """
139
+ LLAMA_3_1 = 'llama-3.1'
140
+ LLAMA_3_2 = 'llama-3.2'
141
+ LLAMA3_405_B = 'llama3-405b'
142
+
143
+
144
+ SUPPORTED_MODELS = {
145
+ OllamaVersion .LLAMA_3_1 : ModelInfo (
146
+ versions = ['meta/llama3-405b-instruct-maas' ],
147
+ label = 'Llama 3.1' ,
148
+ supports = Supports (
149
+ multiturn = True , media = False , tools = True ,
150
+ systemRole = True , output = ['text' , 'json' ]
151
+ )
152
+ ),
153
+ OllamaVersion .LLAMA_3_2 : ModelInfo (
154
+ versions = ['meta/llama-3.2-90b-vision-instruct-maas' ],
155
+ label = 'Llama 3.2' ,
156
+ supports = Supports (
157
+ multiturn = True , media = True , tools = True ,
158
+ systemRole = True , output = ['text' , 'json' ]
159
+ )
160
+ ),
161
+ OllamaVersion .LLAMA3_405_B : ModelInfo (
162
+ versions = [],
163
+ label = 'Llama 3.1 405b' ,
164
+ supports = Supports (
165
+ multiturn = True , media = False , tools = True ,
166
+ systemRole = True , output = ['text' ]
167
+ )
168
+ )
169
+ }
0 commit comments