Skip to content

Commit 7b9b92c

Browse files
author
Niraj Nepal
committed
Added ModelGarden Basic
1 parent cf2b7a9 commit 7b9b92c

File tree

9 files changed

+676
-0
lines changed

9 files changed

+676
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .model_garden import OpenAIFormatModelVersion, vertexai_name
2+
from .modelgarden_plugin import VertexAIModelGarden
3+
4+
# Public API of the package. Every entry of ``__all__`` must be a string;
# listing the function object itself (instead of its name) makes
# ``from <package> import *`` raise ``TypeError``.
__all__ = [
    OpenAIFormatModelVersion.__name__,
    vertexai_name.__name__,
    VertexAIModelGarden.__name__,
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2025 Google LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
import os
4+
5+
from enum import StrEnum
6+
7+
8+
from genkit.plugins.vertex_ai import constants as const
9+
from genkit.core.typing import (
10+
ModelInfo,
11+
Supports,
12+
)
13+
from genkit.veneer.registry import GenkitRegistry
14+
from .openai_compatiblility import OpenAICompatibleModel, OpenAIConfig
15+
16+
17+
def vertexai_name(name: str) -> str:
    """Build the namespaced name of a Vertex AI action.

    Args:
        name: Unqualified action name (for example a model identifier).

    Returns:
        ``name`` prefixed with ``vertexai/``.
    """
    return 'vertexai/{}'.format(name)
27+
28+
class OpenAIFormatModelVersion(StrEnum):
29+
"""Available versions of the llama model.
30+
31+
This enum defines the available versions of the llama model that
32+
can be used through Vertex AI.
33+
"""
34+
LLAMA_3_1 = 'llama-3.1'
35+
LLAMA_3_2 = 'llama-3.2'
36+
37+
38+
# Table of the OpenAI-format models this module can register, keyed by the
# short model identifier. The first entry of ``versions`` is the model name
# handed to the OpenAI-compatible client.
SUPPORTED_OPENAI_FORMAT_MODELS: dict[str, ModelInfo] = {
    # Text-only model.
    OpenAIFormatModelVersion.LLAMA_3_1: ModelInfo(
        label='llama-3.1',
        versions=['meta/llama3-405b-instruct-maas'],
        supports=Supports(
            multiturn=True,
            media=False,
            tools=True,
            systemRole=True,
            output=['text', 'json'],
        ),
    ),
    # Vision model: the only difference in capabilities is ``media=True``.
    OpenAIFormatModelVersion.LLAMA_3_2: ModelInfo(
        label='llama-3.2',
        versions=['meta/llama-3.2-90b-vision-instruct-maas'],
        supports=Supports(
            multiturn=True,
            media=True,
            tools=True,
            systemRole=True,
            output=['text', 'json'],
        ),
    ),
}
62+
63+
64+
class ModelGarden:
    """Registers Model Garden models as Genkit model actions."""

    @classmethod
    def to_openai_compatible_model(
        cls,
        ai: GenkitRegistry,
        model: str,
        location: str,
        project_id: str,
    ) -> None:
        """Define an OpenAI-format Model Garden model on the registry.

        Args:
            ai: Registry on which the model action is defined.
            model: Key of ``SUPPORTED_OPENAI_FORMAT_MODELS`` to register.
            location: Region forwarded to ``OpenAICompatibleModel``.
            project_id: Project ID forwarded to ``OpenAICompatibleModel``.

        Raises:
            ValueError: If ``model`` is not a key of
                ``SUPPORTED_OPENAI_FORMAT_MODELS``.
        """
        # Single lookup instead of a membership test plus repeated indexing.
        try:
            model_info = SUPPORTED_OPENAI_FORMAT_MODELS[model]
        except KeyError:
            raise ValueError(f"Model '{model}' is not supported.") from None

        # The first listed version is the model name sent with each request.
        open_ai_compat = OpenAICompatibleModel(
            model_info.versions[0],
            project_id,
            location,
        )

        ai.define_model(
            # Use the shared helper so the action namespace is defined in
            # exactly one place.
            name=vertexai_name(model),
            fn=open_ai_compat.generate,
            config_schema=OpenAIConfig,
            metadata={
                'model': {
                    'supports': model_info.supports.model_dump(),
                },
            },
        )
94+
95+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2025 Google LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
"""ModelGarden API Compatible Plugin for Genkit."""
6+
7+
from pydantic import BaseModel, ConfigDict
8+
from .model_garden import (
9+
SUPPORTED_OPENAI_FORMAT_MODELS,
10+
ModelGarden
11+
)
12+
from genkit.veneer.plugin import Plugin
13+
from genkit.veneer.registry import GenkitRegistry
14+
from genkit.plugins.vertex_ai import constants as const
15+
import pdb
16+
import os
17+
from pprint import pprint
18+
19+
20+
class CommonPluginOptions(BaseModel):
    """Options held by the Model Garden plugin.

    Fields that are not declared here are rejected (``extra='forbid'``).
    """

    model_config = ConfigDict(populate_by_name=True, extra='forbid')

    # Google Cloud project ID, or None when not provided.
    project_id: str | None = None
    # Region used for the Vertex AI endpoint, or None when not provided.
    location: str | None = None
    # Identifiers of the models to register, or None when not provided.
    models: list[str] | None = None
26+
27+
28+
def vertexai_name(name: str) -> str:
    """Return ``name`` qualified with the Vertex AI action namespace.

    Args:
        name: Unqualified action name.

    Returns:
        The string ``vertexai/<name>``.
    """
    prefix = 'vertexai'
    return f'{prefix}/{name}'
38+
39+
class VertexAIModelGarden(Plugin):
    """Model Garden plugin for Genkit.

    Holds the project, location and model list supplied at construction
    time and, on ``initialize``, registers every requested model that is
    listed in ``SUPPORTED_OPENAI_FORMAT_MODELS``.
    """

    name = "modelgarden"

    def __init__(self, **kwargs):
        """Collect the plugin options.

        Args:
            **kwargs: Recognised keys are ``project_id``, ``location`` and
                ``models``; other keys are ignored.
                ``project_id`` falls back to the environment variable named
                by ``const.GCLOUD_PROJECT`` and ``location`` to
                ``const.DEFAULT_REGION``.
        """
        # ``or`` (rather than a ``dict.get`` default) so that an explicit
        # ``None`` still falls back to the environment / default region.
        self.plugin_options = CommonPluginOptions(
            project_id=(
                kwargs.get('project_id') or os.getenv(const.GCLOUD_PROJECT)
            ),
            location=kwargs.get('location') or const.DEFAULT_REGION,
            models=kwargs.get('models'),
        )

    def initialize(self, ai: GenkitRegistry) -> None:
        """Register the requested OpenAI-compatible models.

        Requested models that are not keys of
        ``SUPPORTED_OPENAI_FORMAT_MODELS`` are skipped.

        Args:
            ai: Registry on which the model actions are defined.
        """
        # ``models`` defaults to None; treat that as "nothing to register"
        # instead of failing while iterating None.
        for model in self.plugin_options.models or []:
            if model not in SUPPORTED_OPENAI_FORMAT_MODELS:
                continue
            ModelGarden.to_openai_compatible_model(
                ai,
                model=model,
                location=self.plugin_options.location,
                project_id=self.plugin_options.project_id,
            )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .openai_compatibility import OpenAICompatibleModel
2+
3+
# Names re-exported by this package.
__all__ = ['OpenAICompatibleModel']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# Copyright 2025 Google LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import pdb
from enum import StrEnum
from typing import Annotated

from google.auth import default, transport
from openai import OpenAI as OpenAIClient
from pydantic import BaseModel, ConfigDict, Field

from genkit.core.action import ActionRunContext
from genkit.core.typing import (
    ModelInfo,
    Supports,
    GenerationCommonConfig,
    GenerateRequest,
    GenerateResponse,
    Message,
    Role,
    TextPart,
    GenerateResponseChunk
)
23+
24+
class ChatMessage(BaseModel):
    """One chat message in the role/content form sent to the client.

    Fields that are not declared here are rejected (``extra='forbid'``).
    """

    model_config = ConfigDict(populate_by_name=True, extra='forbid')

    # Role of the message author.
    role: str
    # Plain-text body of the message.
    content: str
29+
30+
31+
class OpenAIConfig(GenerationCommonConfig):
    """Config for OpenAI model.

    Numeric bounds are declared with ``Field(ge=..., le=...)`` so that
    pydantic validates them; a bare ``range(...)`` inside ``Annotated`` is
    opaque metadata and is not enforced. Bounds are inclusive.
    """

    # Penalty in [-2, 2].
    frequency_penalty: Annotated[float, Field(ge=-2, le=2)] | None = None
    # Per-token bias; each value in [-100, 100].
    logit_bias: dict[str, Annotated[float, Field(ge=-100, le=100)]] | None = None
    logprobs: bool | None = None
    # Penalty in [-2, 2].
    presence_penalty: Annotated[float, Field(ge=-2, le=2)] | None = None
    seed: int | None = None
    # Count in [0, 20].
    top_logprobs: Annotated[int, Field(ge=0, le=20)] | None = None
    user: str | None = None
40+
41+
42+
class ChatCompletionRole(StrEnum):
43+
"""Available roles supported by openai-compatible models."""
44+
USER = 'user'
45+
ASSISTANT = 'assistant'
46+
SYSTEM = 'system'
47+
TOOL = 'tool'
48+
49+
50+
class OpenAICompatibleModel:
51+
"Handles openai compatible model support in model_garden"""
52+
53+
def __init__(self, model: str, project_id: str, location: str):
54+
self._model = model
55+
self._client = self.client_factory(location, project_id)
56+
57+
def client_factory(self, location: str, project_id: str) -> OpenAIClient:
58+
"""Initiates an openai compatible client object and return it."""
59+
if project_id:
60+
credentials, _ = default()
61+
else:
62+
credentials, project_id = default()
63+
64+
credentials.refresh(transport.requests.Request())
65+
base_url = f'https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi'
66+
return OpenAIClient(api_key=credentials.token, base_url=base_url)
67+
68+
69+
def to_openai_messages(self, messages: list[Message]) -> list[ChatMessage]:
70+
if not messages:
71+
raise ValueError('No messages provided in the request.')
72+
return [
73+
ChatMessage(
74+
role=OpenAICompatibleModel.to_openai_role(m.role.value),
75+
content=''.join(
76+
part.root.text
77+
for part in m.content
78+
if part.root.text is not None
79+
),
80+
)
81+
for m in messages
82+
]
83+
def generate(
84+
self, request: GenerateRequest, ctx: ActionRunContext
85+
) -> GenerateResponse:
86+
openai_config: dict = {
87+
'messages': self.to_openai_messages(request.messages),
88+
'model': self._model
89+
}
90+
if ctx.is_streaming:
91+
openai_config['stream'] = True
92+
stream = self._client.chat.completions.create(**openai_config)
93+
for chunk in stream:
94+
choice = chunk.choices[0]
95+
if not choice.delta.content:
96+
continue
97+
98+
response_chunk = GenerateResponseChunk(
99+
role=Role.MODEL,
100+
index=choice.index,
101+
content=[TextPart(text=choice.delta.content)],
102+
)
103+
104+
ctx.send_chunk(response_chunk)
105+
106+
else:
107+
response = self._client.chat.completions.create(**openai_config)
108+
return GenerateResponse(
109+
request=request,
110+
message=Message(
111+
role=Role.MODEL,
112+
content=[TextPart(text=response.choices[0].message.content)],
113+
),
114+
)
115+
116+
@staticmethod
117+
def to_openai_role(role: Role) -> ChatCompletionRole:
118+
"""Converts Role enum to corrosponding OpenAI Compatible role."""
119+
match role:
120+
case Role.USER:
121+
return ChatCompletionRole.USER
122+
case Role.MODEL:
123+
return ChatCompletionRole.ASSISTANT # "model" maps to "assistant"
124+
case Role.SYSTEM:
125+
return ChatCompletionRole.SYSTEM
126+
case Role.TOOL:
127+
return ChatCompletionRole.TOOL
128+
case _:
129+
raise ValueError(f"Role '{role}' doesn't map to an OpenAI role.")
130+
131+
132+
133+
class OllamaVersion(StrEnum):
134+
"""Available versions of the llama model.
135+
136+
This enum defines the available versions of the llama model that
137+
can be used through Vertex AI.
138+
"""
139+
LLAMA_3_1 = 'llama-3.1'
140+
LLAMA_3_2 = 'llama-3.2'
141+
LLAMA3_405_B = 'llama3-405b'
142+
143+
144+
# Model metadata keyed by ``OllamaVersion``.
SUPPORTED_MODELS = {
    OllamaVersion.LLAMA_3_1: ModelInfo(
        label='Llama 3.1',
        versions=['meta/llama3-405b-instruct-maas'],
        supports=Supports(
            multiturn=True,
            media=False,
            tools=True,
            systemRole=True,
            output=['text', 'json'],
        ),
    ),
    OllamaVersion.LLAMA_3_2: ModelInfo(
        label='Llama 3.2',
        versions=['meta/llama-3.2-90b-vision-instruct-maas'],
        supports=Supports(
            multiturn=True,
            media=True,
            tools=True,
            systemRole=True,
            output=['text', 'json'],
        ),
    ),
    # This entry lists no versions and supports text output only.
    OllamaVersion.LLAMA3_405_B: ModelInfo(
        label='Llama 3.1 405b',
        versions=[],
        supports=Supports(
            multiturn=True,
            media=False,
            tools=True,
            systemRole=True,
            output=['text'],
        ),
    ),
}

0 commit comments

Comments
 (0)