Skip to content

Commit 4e9eadc

Browse files
committed
feat: Added ModelGarden plugin
1 parent 6361eb9 commit 4e9eadc

File tree

12 files changed

+606
-731
lines changed

12 files changed

+606
-731
lines changed

py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from .handler import OpenAIModelHandler
2121
from .model import OpenAIModel
22-
from .model_info import SUPPORTED_OPENAI_MODELS
22+
from .model_info import SUPPORTED_OPENAI_COMPAT_MODELS, SUPPORTED_OPENAI_MODELS, PluginSource
2323

2424

2525
def package_name() -> str:
@@ -28,7 +28,9 @@ def package_name() -> str:
2828

2929
__all__ = [
3030
'OpenAIModel',
31+
'PluginSource',
3132
'SUPPORTED_OPENAI_MODELS',
33+
'SUPPORTED_OPENAI_COMPAT_MODELS',
3234
'OpenAIModelHandler',
3335
'package_name',
3436
]

py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#
1515
# SPDX-License-Identifier: Apache-2.0
1616

17-
1817
"""OpenAI Compatible Model handlers for Genkit."""
1918

2019
from collections.abc import Callable
@@ -25,7 +24,9 @@
2524
from genkit.ai import ActionRunContext, GenkitRegistry
2625
from genkit.plugins.compat_oai.models.model import OpenAIModel
2726
from genkit.plugins.compat_oai.models.model_info import (
27+
SUPPORTED_OPENAI_COMPAT_MODELS,
2828
SUPPORTED_OPENAI_MODELS,
29+
PluginSource,
2930
)
3031
from genkit.plugins.compat_oai.typing import OpenAIConfig
3132
from genkit.types import (
@@ -37,17 +38,35 @@
3738
class OpenAIModelHandler:
3839
"""Handles OpenAI API interactions for the Genkit plugin."""
3940

40-
def __init__(self, model: Any) -> None:
41+
def __init__(self, model: Any, source: PluginSource = PluginSource.OPENAI) -> None:
4142
"""Initializes the OpenAIModelHandler with a specified model.
4243
4344
Args:
4445
model: An instance of a Model subclass representing the OpenAI model.
46+
source: Helps distinguish if model handler is called from model-garden plugin.
47+
Default source is openai.
4548
"""
4649
self._model = model
50+
self._source = source
51+
52+
@staticmethod
53+
def _get_supported_models(source: PluginSource) -> dict[str, Any]:
54+
"""Returns the supported models based on the plugin source.
55+
Args:
56+
source: Helps distinguish if model handler is called from model-garden plugin.
57+
Default source is openai.
58+
59+
Returns:
60+
Openai models if source is openai. Merges supported openai models with openai-compat models if source is model-garden.
61+
62+
"""
63+
if source == PluginSource.MODEL_GARDEN:
64+
return {**SUPPORTED_OPENAI_MODELS, **SUPPORTED_OPENAI_COMPAT_MODELS}
65+
return SUPPORTED_OPENAI_MODELS
4766

4867
@classmethod
4968
def get_model_handler(
50-
cls, model: str, client: OpenAI, registry: GenkitRegistry
69+
cls, model: str, client: OpenAI, registry: GenkitRegistry, source: PluginSource = PluginSource.OPENAI
5170
) -> Callable[[GenerateRequest, ActionRunContext], GenerateResponse]:
5271
"""Factory method to initialize the model handler for the specified OpenAI model.
5372
@@ -61,18 +80,22 @@ def get_model_handler(
6180
model: The OpenAI model name.
6281
client: OpenAI client instance.
6382
registry: Genkit registry instance.
83+
source: Helps distinguish if model handler is called from model-garden plugin.
84+
Default source is openai.
6485
6586
Returns:
6687
A callable function that acts as an action handler.
6788
6889
Raises:
6990
ValueError: If the specified model is not supported.
7091
"""
71-
if model not in SUPPORTED_OPENAI_MODELS:
92+
supported_models = cls._get_supported_models(source)
93+
94+
if model not in supported_models:
7295
raise ValueError(f"Model '{model}' is not supported.")
7396

7497
openai_model = OpenAIModel(model, client, registry)
75-
return cls(openai_model).generate
98+
return cls(openai_model, source).generate
7699

77100
def _validate_version(self, version: str) -> None:
78101
"""Validates whether the specified model version is supported.
@@ -83,7 +106,8 @@ def _validate_version(self, version: str) -> None:
83106
Raises:
84107
ValueError: If the specified model version is not supported.
85108
"""
86-
model_info = SUPPORTED_OPENAI_MODELS[self._model.name]
109+
supported_models = self._get_supported_models(self._source)
110+
model_info = supported_models[self._model.name]
87111
if version not in model_info.versions:
88112
raise ValueError(f"Model version '{version}' is not supported.")
89113

py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model_info.py

+38
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,32 @@
1717

1818
"""OpenAI Compatible Models for Genkit."""
1919

20+
from enum import StrEnum
21+
2022
from genkit.plugins.compat_oai.typing import SupportedOutputFormat
2123
from genkit.types import (
2224
ModelInfo,
2325
Supports,
2426
)
2527

28+
OPENAI = 'openai'
29+
MODEL_GARDEN = 'model-garden'
30+
31+
32+
class PluginSource(StrEnum):
33+
OPENAI = 'openai'
34+
MODEL_GARDEN = 'model-garden'
35+
36+
2637
GPT_3_5_TURBO = 'gpt-3.5-turbo'
2738
GPT_4 = 'gpt-4'
2839
GPT_4_TURBO = 'gpt-4-turbo'
2940
GPT_4O = 'gpt-4o'
3041
GPT_4O_MINI = 'gpt-4o-mini'
3142
O1_MINI = 'o1-mini'
3243

44+
LLAMA_3_1 = 'llama-3.1'
45+
LLAMA_3_2 = 'llama-3.2'
3346

3447
SUPPORTED_OPENAI_MODELS: dict[str, ModelInfo] = {
3548
GPT_3_5_TURBO: ModelInfo(
@@ -124,3 +137,28 @@
124137
),
125138
),
126139
}
140+
141+
SUPPORTED_OPENAI_COMPAT_MODELS: dict[str, ModelInfo] = {
142+
LLAMA_3_1: ModelInfo(
143+
versions=['meta/llama3-405b-instruct-maas'],
144+
label='llama-3.1',
145+
supports=Supports(
146+
multiturn=True,
147+
media=False,
148+
tools=True,
149+
systemRole=True,
150+
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
151+
),
152+
),
153+
LLAMA_3_2: ModelInfo(
154+
versions=['meta/llama-3.2-90b-vision-instruct-maas'],
155+
label='llama-3.2',
156+
supports=Supports(
157+
multiturn=True,
158+
media=True,
159+
tools=True,
160+
systemRole=True,
161+
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
162+
),
163+
),
164+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# SPDX-License-Identifier: Apache-2.0
16+
17+
from .model_garden import vertexai_name
18+
from .modelgarden_plugin import VertexAIModelGarden
19+
20+
__all__ = [vertexai_name, VertexAIModelGarden]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# SPDX-License-Identifier: Apache-2.0
16+
17+
from google.auth import default, transport
18+
from openai import OpenAI as _OpenAI
19+
20+
21+
class OpenAIClient:
22+
"""Handles OpenAI API client Initialization."""
23+
24+
def __new__(cls, **openai_params) -> _OpenAI:
25+
"""Initializes the OpenAIClient based on the plugin source."""
26+
location = openai_params.get('location')
27+
project_id = openai_params.get('project_id')
28+
if project_id:
29+
credentials, _ = default()
30+
else:
31+
credentials, project_id = default()
32+
33+
credentials.refresh(transport.requests.Request())
34+
base_url = f'https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi'
35+
return _OpenAI(api_key=credentials.token, base_url=base_url)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# SPDX-License-Identifier: Apache-2.0
16+
17+
from genkit.ai import GenkitRegistry
18+
from genkit.plugins.compat_oai.models import (
19+
SUPPORTED_OPENAI_COMPAT_MODELS,
20+
OpenAIModelHandler,
21+
PluginSource,
22+
)
23+
from genkit.plugins.compat_oai.models.model_info import PluginSource
24+
from genkit.plugins.compat_oai.typing import OpenAIConfig
25+
26+
from .client import OpenAIClient
27+
28+
OPENAI_COMPAT = 'openai-compat'
29+
30+
31+
def vertexai_name(name: str) -> str:
32+
"""Create a Vertex AI action name.
33+
34+
Args:
35+
name: Base name for the action.
36+
37+
Returns:
38+
The fully qualified Vertex AI action name.
39+
"""
40+
return f'vertexai/{name}'
41+
42+
43+
class ModelGarden:
44+
@staticmethod
45+
def get_model_info(name: str) -> dict[str, str] | None:
46+
"""Returns model type and name for a given model.
47+
48+
Args:
49+
name: Name of the model for which type and name are to be returned
50+
51+
"""
52+
if SUPPORTED_OPENAI_COMPAT_MODELS.get(name):
53+
return {'name': SUPPORTED_OPENAI_COMPAT_MODELS.get(name).label, 'type': OPENAI_COMPAT}
54+
55+
@classmethod
56+
def to_openai_compatible_model(cls, ai: GenkitRegistry, model: str, location: str, project_id: str):
57+
if model not in SUPPORTED_OPENAI_COMPAT_MODELS:
58+
raise ValueError(f"Model '{model}' is not supported.")
59+
openai_params = {'location': location, 'project_id': project_id}
60+
openai_client = OpenAIClient(**openai_params)
61+
handler = OpenAIModelHandler.get_model_handler(
62+
model=model, client=openai_client, registry=ai, source=PluginSource.MODEL_GARDEN
63+
)
64+
65+
supports = SUPPORTED_OPENAI_COMPAT_MODELS[model].supports.model_dump()
66+
67+
ai.define_model(
68+
name=vertexai_name(model),
69+
fn=handler,
70+
config_schema=OpenAIConfig,
71+
metadata={'model': {'supports': supports}},
72+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# SPDX-License-Identifier: Apache-2.0
16+
17+
"""ModelGarden API Compatible Plugin for Genkit."""
18+
19+
import os
20+
21+
from genkit.ai import GenkitRegistry, Plugin
22+
from genkit.plugins.vertex_ai import constants as const
23+
24+
from .model_garden import OPENAI_COMPAT, ModelGarden
25+
26+
27+
class VertexAIModelGarden(Plugin):
28+
"""Model Garden plugin for Genkit.
29+
30+
This plugin provides integration with Google Cloud's Vertex AI platform,
31+
enabling the use of Vertex AI models and services within the Genkit
32+
framework. It handles initialization of the Model Garden client and
33+
registration of model actions.
34+
"""
35+
36+
name = 'vertex-ai-model-garden'
37+
38+
def __init__(self, project_id: str | None = None, location: str | None = None, models: list[str] | None = None):
39+
"""Initialize the plugin by registering actions with the registry."""
40+
self.project_id = project_id if project_id is not None else os.getenv(const.GCLOUD_PROJECT)
41+
self.location = location if location is not None else const.DEFAULT_REGION
42+
self.models = models
43+
44+
def initialize(self, ai: GenkitRegistry) -> None:
45+
"""Handles actions for various openaicompatible models."""
46+
models = self.models
47+
if models:
48+
for model in models:
49+
model_info = ModelGarden.get_model_info(model)
50+
if model_info:
51+
if model_info['type'] == OPENAI_COMPAT:
52+
ModelGarden.to_openai_compatible_model(
53+
ai,
54+
model=model_info['name'],
55+
location=self.location,
56+
project_id=self.project_id,
57+
)

0 commit comments

Comments
 (0)