Skip to content

Commit 91ca7ae

Browse files
committed
feat: Added ModelGarden plugin
1 parent 6361eb9 commit 91ca7ae

File tree

14 files changed

+1468
-730
lines changed

14 files changed

+1468
-730
lines changed

py/packages/genkit/src/genkit/core/typing.py.backup

+925
Large diffs are not rendered by default.

py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from .handler import OpenAIModelHandler
2121
from .model import OpenAIModel
22-
from .model_info import SUPPORTED_OPENAI_MODELS
22+
from .model_info import SUPPORTED_OPENAI_COMPAT_MODELS, SUPPORTED_OPENAI_MODELS, PluginSource
2323

2424

2525
def package_name() -> str:
@@ -28,7 +28,9 @@ def package_name() -> str:
2828

2929
__all__ = [
3030
'OpenAIModel',
31+
'PluginSource',
3132
'SUPPORTED_OPENAI_MODELS',
33+
'SUPPORTED_OPENAI_COMPAT_MODELS',
3234
'OpenAIModelHandler',
3335
'package_name',
3436
]

py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py

+30-5
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
from genkit.ai import ActionRunContext, GenkitRegistry
2626
from genkit.plugins.compat_oai.models.model import OpenAIModel
2727
from genkit.plugins.compat_oai.models.model_info import (
28+
SUPPORTED_OPENAI_COMPAT_MODELS,
2829
SUPPORTED_OPENAI_MODELS,
30+
PluginSource,
2931
)
3032
from genkit.plugins.compat_oai.typing import OpenAIConfig
3133
from genkit.types import (
@@ -37,17 +39,35 @@
3739
class OpenAIModelHandler:
3840
"""Handles OpenAI API interactions for the Genkit plugin."""
3941

40-
def __init__(self, model: Any) -> None:
42+
def __init__(self, model: Any, source: PluginSource = PluginSource.OPENAI) -> None:
4143
"""Initializes the OpenAIModelHandler with a specified model.
4244
4345
Args:
4446
model: An instance of a Model subclass representing the OpenAI model.
47+
source: Helps distinguish if model handler is called from model-garden plugin.
48+
Default source is openai.
4549
"""
4650
self._model = model
51+
self._source = source
52+
53+
@staticmethod
54+
def _get_supported_models(source: PluginSource) -> dict[str, Any]:
55+
"""Returns the supported models based on the plugin source.
56+
Args:
57+
source: Helps distinguish if model handler is called from model-garden plugin.
58+
Default source is openai.
59+
60+
Returns:
61+
Openai models if source is openai. Merges supported openai models with openai-compat models if source is model-garden.
62+
63+
"""
64+
if source == PluginSource.MODEL_GARDEN:
65+
return {**SUPPORTED_OPENAI_MODELS, **SUPPORTED_OPENAI_COMPAT_MODELS}
66+
return SUPPORTED_OPENAI_MODELS
4767

4868
@classmethod
4969
def get_model_handler(
50-
cls, model: str, client: OpenAI, registry: GenkitRegistry
70+
cls, model: str, client: OpenAI, registry: GenkitRegistry, source: PluginSource = PluginSource.OPENAI
5171
) -> Callable[[GenerateRequest, ActionRunContext], GenerateResponse]:
5272
"""Factory method to initialize the model handler for the specified OpenAI model.
5373
@@ -61,18 +81,22 @@ def get_model_handler(
6181
model: The OpenAI model name.
6282
client: OpenAI client instance.
6383
registry: Genkit registry instance.
84+
source: Helps distinguish if model handler is called from model-garden plugin.
85+
Default source is openai.
6486
6587
Returns:
6688
A callable function that acts as an action handler.
6789
6890
Raises:
6991
ValueError: If the specified model is not supported.
7092
"""
71-
if model not in SUPPORTED_OPENAI_MODELS:
93+
supported_models = cls._get_supported_models(source)
94+
95+
if model not in supported_models:
7296
raise ValueError(f"Model '{model}' is not supported.")
7397

7498
openai_model = OpenAIModel(model, client, registry)
75-
return cls(openai_model).generate
99+
return cls(openai_model, source).generate
76100

77101
def _validate_version(self, version: str) -> None:
78102
"""Validates whether the specified model version is supported.
@@ -83,7 +107,8 @@ def _validate_version(self, version: str) -> None:
83107
Raises:
84108
ValueError: If the specified model version is not supported.
85109
"""
86-
model_info = SUPPORTED_OPENAI_MODELS[self._model.name]
110+
supported_models = self._get_supported_models(self._source)
111+
model_info = supported_models[self._model.name]
87112
if version not in model_info.versions:
88113
raise ValueError(f"Model version '{version}' is not supported.")
89114

py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model_info.py

+38
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,32 @@
1717

1818
"""OpenAI Compatible Models for Genkit."""
1919

20+
from enum import StrEnum
21+
2022
from genkit.plugins.compat_oai.typing import SupportedOutputFormat
2123
from genkit.types import (
2224
ModelInfo,
2325
Supports,
2426
)
2527

28+
OPENAI = 'openai'
29+
MODEL_GARDEN = 'model-garden'
30+
31+
32+
class PluginSource(StrEnum):
33+
OPENAI = 'openai'
34+
MODEL_GARDEN = 'model-garden'
35+
36+
2637
GPT_3_5_TURBO = 'gpt-3.5-turbo'
2738
GPT_4 = 'gpt-4'
2839
GPT_4_TURBO = 'gpt-4-turbo'
2940
GPT_4O = 'gpt-4o'
3041
GPT_4O_MINI = 'gpt-4o-mini'
3142
O1_MINI = 'o1-mini'
3243

44+
LLAMA_3_1 = 'llama-3.1'
45+
LLAMA_3_2 = 'llama-3.2'
3346

3447
SUPPORTED_OPENAI_MODELS: dict[str, ModelInfo] = {
3548
GPT_3_5_TURBO: ModelInfo(
@@ -124,3 +137,28 @@
124137
),
125138
),
126139
}
140+
141+
SUPPORTED_OPENAI_COMPAT_MODELS: dict[str, ModelInfo] = {
142+
LLAMA_3_1: ModelInfo(
143+
versions=['meta/llama3-405b-instruct-maas'],
144+
label='llama-3.1',
145+
supports=Supports(
146+
multiturn=True,
147+
media=False,
148+
tools=True,
149+
systemRole=True,
150+
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
151+
),
152+
),
153+
LLAMA_3_2: ModelInfo(
154+
versions=['meta/llama-3.2-90b-vision-instruct-maas'],
155+
label='llama-3.2',
156+
supports=Supports(
157+
multiturn=True,
158+
media=True,
159+
tools=True,
160+
systemRole=True,
161+
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
162+
),
163+
),
164+
}

py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
SUPPORTED_OPENAI_MODELS,
2525
OpenAIModelHandler,
2626
)
27+
from genkit.plugins.compat_oai.openai_client_handler import OpenAIClientHandler
2728
from genkit.plugins.compat_oai.typing import OpenAIConfig
2829

2930

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# SPDX-License-Identifier: Apache-2.0
16+
17+
from .model_garden import vertexai_name
18+
from .modelgarden_plugin import VertexAIModelGarden
19+
20+
__all__ = [vertexai_name, VertexAIModelGarden.__name__]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# SPDX-License-Identifier: Apache-2.0
16+
17+
from google.auth import default, transport
18+
from openai import OpenAI as _OpenAI
19+
20+
21+
class OpenAIClient:
22+
"""Handles OpenAI API client Initialization."""
23+
24+
def __new__(cls, **openai_params) -> _OpenAI:
25+
"""Initializes the OpenAIClient based on the plugin source."""
26+
location = openai_params.get('location')
27+
project_id = openai_params.get('project_id')
28+
if project_id:
29+
credentials, _ = default()
30+
else:
31+
credentials, project_id = default()
32+
33+
credentials.refresh(transport.requests.Request())
34+
base_url = f'https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi'
35+
return _OpenAI(api_key=credentials.token, base_url=base_url)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2025 Google LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
from genkit.ai import GenkitRegistry
4+
from genkit.plugins.compat_oai.models import (
5+
SUPPORTED_OPENAI_COMPAT_MODELS,
6+
OpenAIModelHandler,
7+
PluginSource,
8+
)
9+
from genkit.plugins.compat_oai.openai_client_handler import PluginSource
10+
from genkit.plugins.compat_oai.typing import OpenAIConfig
11+
12+
from .client import OpenAIClient
13+
14+
OPENAI_COMPAT = 'openai-compat'
15+
16+
17+
def vertexai_name(name: str) -> str:
18+
"""Create a Vertex AI action name.
19+
20+
Args:
21+
name: Base name for the action.
22+
23+
Returns:
24+
The fully qualified Vertex AI action name.
25+
"""
26+
return f'vertexai/{name}'
27+
28+
29+
class ModelGarden:
30+
@staticmethod
31+
def get_model_info(name: str) -> dict[str, str] | None:
32+
"""Returns model type and name for a given model.
33+
34+
Args:
35+
name: Name of the model for which type and name are to be returned
36+
37+
"""
38+
if SUPPORTED_OPENAI_COMPAT_MODELS.get(name):
39+
return {'name': SUPPORTED_OPENAI_COMPAT_MODELS.get(name).label, 'type': OPENAI_COMPAT}
40+
41+
@classmethod
42+
def to_openai_compatible_model(cls, ai: GenkitRegistry, model: str, location: str, project_id: str):
43+
if model not in SUPPORTED_OPENAI_COMPAT_MODELS:
44+
raise ValueError(f"Model '{model}' is not supported.")
45+
openai_params = {'location': location, 'project_id': project_id}
46+
openai_client = OpenAIClient(**openai_params)
47+
handler = OpenAIModelHandler.get_model_handler(
48+
model=model, client=openai_client, registry=ai, source=PluginSource.MODEL_GARDEN
49+
)
50+
51+
supports = SUPPORTED_OPENAI_COMPAT_MODELS[model].supports.model_dump()
52+
53+
ai.define_model(
54+
name=f'vertexai/{model}',
55+
fn=handler,
56+
config_schema=OpenAIConfig,
57+
metadata={'model': {'supports': supports}},
58+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2025 Google LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
"""ModelGarden API Compatible Plugin for Genkit."""
6+
7+
import os
8+
9+
from genkit.ai import GenkitRegistry, Plugin
10+
from genkit.plugins.vertex_ai import constants as const
11+
12+
from .model_garden import OPENAI_COMPAT, ModelGarden
13+
14+
15+
class VertexAIModelGarden(Plugin):
16+
"""Model Garden plugin for Genkit.
17+
18+
This plugin provides integration with Google Cloud's Vertex AI platform,
19+
enabling the use of Vertex AI models and services within the Genkit
20+
framework. It handles initialization of the Model Garden client and
21+
registration of model actions.
22+
"""
23+
24+
name = 'vertex-ai-model-garden'
25+
26+
def __init__(self, project_id: str | None = None, location: str | None = None, models: list[str] | None = None):
27+
"""Initialize the plugin by registering actions with the registry."""
28+
self.project_id = project_id if project_id is not None else os.getenv(const.GCLOUD_PROJECT)
29+
self.location = location if location is not None else const.DEFAULT_REGION
30+
self.models = models
31+
32+
def initialize(self, ai: GenkitRegistry) -> None:
33+
"""Handles actions for various openaicompatible models."""
34+
models = self.models
35+
if models:
36+
for model in models:
37+
model_info = ModelGarden.get_model_info(model)
38+
if model_info:
39+
if model_info['type'] == OPENAI_COMPAT:
40+
ModelGarden.to_openai_compatible_model(
41+
ai,
42+
model=model_info['name'],
43+
location=self.location,
44+
project_id=self.project_id,
45+
)

0 commit comments

Comments
 (0)