Skip to content

Commit 4131b28

Browse files
committed
feat: Modelgarden uses openai plugin
1 parent 348396e commit 4131b28

File tree

11 files changed

+137
-1017
lines changed

11 files changed

+137
-1017
lines changed

py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from .handler import OpenAIModelHandler
2121
from .model import OpenAIModel
22-
from .model_info import SUPPORTED_OPENAI_MODELS
22+
from .model_info import SUPPORTED_OPENAI_COMPAT_MODELS, SUPPORTED_OPENAI_MODELS, PluginSource
2323

2424

2525
def package_name() -> str:
@@ -28,7 +28,9 @@ def package_name() -> str:
2828

2929
__all__ = [
3030
'OpenAIModel',
31+
'PluginSource',
3132
'SUPPORTED_OPENAI_MODELS',
33+
'SUPPORTED_OPENAI_COMPAT_MODELS',
3234
'OpenAIModelHandler',
3335
'package_name',
3436
]

py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
from genkit.ai import ActionRunContext, GenkitRegistry
2626
from genkit.plugins.compat_oai.models.model import OpenAIModel
2727
from genkit.plugins.compat_oai.models.model_info import (
28+
SUPPORTED_OPENAI_COMPAT_MODELS,
2829
SUPPORTED_OPENAI_MODELS,
29-
SUPPORTED_OPENAI_COMPAT_MODELS
30+
PluginSource,
3031
)
3132
from genkit.plugins.compat_oai.typing import OpenAIConfig
3233
from genkit.types import (
@@ -38,17 +39,35 @@
3839
class OpenAIModelHandler:
3940
"""Handles OpenAI API interactions for the Genkit plugin."""
4041

41-
def __init__(self, model: Any) -> None:
42+
def __init__(self, model: Any, source: PluginSource = PluginSource.OPENAI) -> None:
4243
"""Initializes the OpenAIModelHandler with a specified model.
4344
4445
Args:
4546
model: An instance of a Model subclass representing the OpenAI model.
47+
source: Helps distinguish if model handler is called from model-garden plugin.
48+
Default source is openai.
4649
"""
4750
self._model = model
51+
self._source = source
52+
53+
@staticmethod
54+
def _get_supported_models(source: PluginSource) -> dict[str, Any]:
55+
"""Returns the supported models based on the plugin source.
56+
Args:
57+
source: Helps distinguish if model handler is called from model-garden plugin.
58+
Default source is openai.
59+
60+
Returns:
61+
Openai models if source is openai. Merges supported openai models with openai-compat models if source is model-garden.
62+
63+
"""
64+
if source == PluginSource.MODEL_GARDEN:
65+
return {**SUPPORTED_OPENAI_MODELS, **SUPPORTED_OPENAI_COMPAT_MODELS}
66+
return SUPPORTED_OPENAI_MODELS
4867

4968
@classmethod
5069
def get_model_handler(
51-
cls, model: str, client: OpenAI, registry: GenkitRegistry
70+
cls, model: str, client: OpenAI, registry: GenkitRegistry, source: PluginSource = PluginSource.OPENAI
5271
) -> Callable[[GenerateRequest, ActionRunContext], GenerateResponse]:
5372
"""Factory method to initialize the model handler for the specified OpenAI model.
5473
@@ -62,18 +81,22 @@ def get_model_handler(
6281
model: The OpenAI model name.
6382
client: OpenAI client instance.
6483
registry: Genkit registry instance.
84+
source: Helps distinguish if model handler is called from model-garden plugin.
85+
Default source is openai.
6586
6687
Returns:
6788
A callable function that acts as an action handler.
6889
6990
Raises:
7091
ValueError: If the specified model is not supported.
7192
"""
72-
if model not in SUPPORTED_OPENAI_MODELS:
93+
supported_models = cls._get_supported_models(source)
94+
95+
if model not in supported_models:
7396
raise ValueError(f"Model '{model}' is not supported.")
7497

7598
openai_model = OpenAIModel(model, client, registry)
76-
return cls(openai_model).generate
99+
return cls(openai_model, source).generate
77100

78101
def _validate_version(self, version: str) -> None:
79102
"""Validates whether the specified model version is supported.
@@ -84,7 +107,8 @@ def _validate_version(self, version: str) -> None:
84107
Raises:
85108
ValueError: If the specified model version is not supported.
86109
"""
87-
model_info = SUPPORTED_OPENAI_MODELS[self._model.name]
110+
supported_models = self._get_supported_models(self._source)
111+
model_info = supported_models[self._model.name]
88112
if version not in model_info.versions:
89113
raise ValueError(f"Model version '{version}' is not supported.")
90114

py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model_info.py

+13-27
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,23 @@
1717

1818
"""OpenAI Compatible Models for Genkit."""
1919

20+
from enum import StrEnum
21+
2022
from genkit.plugins.compat_oai.typing import SupportedOutputFormat
2123
from genkit.types import (
2224
ModelInfo,
2325
Supports,
2426
)
2527

28+
OPENAI = 'openai'
29+
MODEL_GARDEN = 'model-garden'
30+
31+
32+
class PluginSource(StrEnum):
33+
OPENAI = 'openai'
34+
MODEL_GARDEN = 'model-garden'
35+
36+
2637
GPT_3_5_TURBO = 'gpt-3.5-turbo'
2738
GPT_4 = 'gpt-4'
2839
GPT_4_TURBO = 'gpt-4-turbo'
@@ -136,32 +147,7 @@
136147
media=False,
137148
tools=True,
138149
systemRole=True,
139-
output=['text', 'json']
140-
),
141-
),
142-
LLAMA_3_2: ModelInfo(
143-
versions=['meta/llama-3.2-90b-vision-instruct-maas'],
144-
label='llama-3.2',
145-
supports=Supports(
146-
multiturn=True,
147-
media=True,
148-
tools=True,
149-
systemRole=True,
150-
output=['text', 'json']
151-
),
152-
),
153-
}
154-
155-
SUPPORTED_OPENAI_COMPAT_MODELS: dict[str, ModelInfo] = {
156-
LLAMA_3_1: ModelInfo(
157-
versions=['meta/llama3-405b-instruct-maas'],
158-
label='llama-3.1',
159-
supports=Supports(
160-
multiturn=True,
161-
media=False,
162-
tools=True,
163-
systemRole=True,
164-
output=['text', 'json']
150+
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
165151
),
166152
),
167153
LLAMA_3_2: ModelInfo(
@@ -172,7 +158,7 @@
172158
media=True,
173159
tools=True,
174160
systemRole=True,
175-
output=['text', 'json']
161+
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
176162
),
177163
),
178164
}

py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
SUPPORTED_OPENAI_MODELS,
2525
OpenAIModelHandler,
2626
)
27-
from genkit.plugins.compat_oai.typing import OpenAIConfig
2827
from genkit.plugins.compat_oai.openai_client_handler import OpenAIClientHandler
28+
from genkit.plugins.compat_oai.typing import OpenAIConfig
2929

3030

3131
class OpenAI(Plugin):

py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#
1515
# SPDX-License-Identifier: Apache-2.0
1616

17-
from .model_garden import OpenAIFormatModelVersion, vertexai_name
17+
from .model_garden import vertexai_name
1818
from .modelgarden_plugin import VertexAIModelGarden
1919

20-
__all__ = [OpenAIFormatModelVersion.__name__, vertexai_name, VertexAIModelGarden.__name__]
20+
__all__ = [vertexai_name, VertexAIModelGarden.__name__]

py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_client_handler.py renamed to py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/client.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,22 @@
1414
#
1515
# SPDX-License-Identifier: Apache-2.0
1616

17-
from openai import OpenAI as OpenAIClient
17+
from google.auth import default, transport
18+
from openai import OpenAI as _OpenAI
1819

1920

20-
class OpenAIClientHandler:
21+
class OpenAIClient:
2122
"""Handles OpenAI API client Initialization."""
2223

23-
@staticmethod
24-
def initialize_client(plugin_source: Optional[str] = None, **openai_params):
24+
def __new__(cls, **openai_params) -> _OpenAI:
2525
"""Initializes the OpenAIClient based on the plugin source."""
26+
location = openai_params.get('location')
27+
project_id = openai_params.get('project_id')
28+
if project_id:
29+
credentials, _ = default()
30+
else:
31+
credentials, project_id = default()
32+
33+
credentials.refresh(transport.requests.Request())
34+
base_url = f'https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi'
35+
return _OpenAI(api_key=credentials.token, base_url=base_url)
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
# Copyright 2025 Google LLC
22
# SPDX-License-Identifier: Apache-2.0
3-
import os
4-
from enum import StrEnum
5-
6-
from genkit.ai.registry import GenkitRegistry
7-
from genkit.core.typing import (
8-
ModelInfo,
9-
Supports,
3+
from genkit.ai import GenkitRegistry
4+
from genkit.plugins.compat_oai.models import (
5+
SUPPORTED_OPENAI_COMPAT_MODELS,
6+
OpenAIModelHandler,
7+
PluginSource,
108
)
9+
from genkit.plugins.compat_oai.openai_client_handler import PluginSource
1110
from genkit.plugins.compat_oai.typing import OpenAIConfig
12-
from genkit.plugins.vertex_ai import constants as const
1311

14-
from .openai_compatibility import OpenAICompatibleModel
12+
from .client import OpenAIClient
13+
14+
OPENAI_COMPAT = 'openai-compat'
1515

1616

1717
def vertexai_name(name: str) -> str:
@@ -26,43 +26,33 @@ def vertexai_name(name: str) -> str:
2626
return f'vertexai/{name}'
2727

2828

29-
class OpenAIFormatModelVersion(StrEnum):
30-
"""Available versions of the llama model.
31-
32-
This enum defines the available versions of the llama model that
33-
can be used through Vertex AI.
34-
"""
35-
36-
LLAMA_3_1 = 'llama-3.1'
37-
LLAMA_3_2 = 'llama-3.2'
38-
29+
class ModelGarden:
30+
@staticmethod
31+
def get_model_info(name: str) -> dict[str, str] | None:
32+
"""Returns model type and name for a given model.
3933
40-
SUPPORTED_OPENAI_FORMAT_MODELS: dict[str, ModelInfo] = {
41-
OpenAIFormatModelVersion.LLAMA_3_1: ModelInfo(
42-
versions=['meta/llama3-405b-instruct-maas'],
43-
label='llama-3.1',
44-
supports=Supports(multiturn=True, media=False, tools=True, systemRole=True, output=['text', 'json']),
45-
),
46-
OpenAIFormatModelVersion.LLAMA_3_2: ModelInfo(
47-
versions=['meta/llama-3.2-90b-vision-instruct-maas'],
48-
label='llama-3.2',
49-
supports=Supports(multiturn=True, media=True, tools=True, systemRole=True, output=['text', 'json']),
50-
),
51-
}
34+
Args:
35+
name: Name of the model for which type and name are to be returned
5236
37+
"""
38+
if SUPPORTED_OPENAI_COMPAT_MODELS.get(name):
39+
return {'name': SUPPORTED_OPENAI_COMPAT_MODELS.get(name).label, 'type': OPENAI_COMPAT}
5340

54-
class ModelGarden:
5541
@classmethod
5642
def to_openai_compatible_model(cls, ai: GenkitRegistry, model: str, location: str, project_id: str):
57-
if model not in SUPPORTED_OPENAI_FORMAT_MODELS:
43+
if model not in SUPPORTED_OPENAI_COMPAT_MODELS:
5844
raise ValueError(f"Model '{model}' is not supported.")
59-
model_version = SUPPORTED_OPENAI_FORMAT_MODELS[model].versions[0]
60-
open_ai_compat = OpenAICompatibleModel(model_version, project_id, location)
61-
supports = SUPPORTED_OPENAI_FORMAT_MODELS[model].supports.model_dump()
45+
openai_params = {'location': location, 'project_id': project_id}
46+
openai_client = OpenAIClient(**openai_params)
47+
handler = OpenAIModelHandler.get_model_handler(
48+
model=model, client=openai_client, registry=ai, source=PluginSource.MODEL_GARDEN
49+
)
50+
51+
supports = SUPPORTED_OPENAI_COMPAT_MODELS[model].supports.model_dump()
6252

6353
ai.define_model(
6454
name=f'vertexai/{model}',
65-
fn=open_ai_compat.generate,
55+
fn=handler,
6656
config_schema=OpenAIConfig,
6757
metadata={'model': {'supports': supports}},
6858
)

py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py

+19-43
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,11 @@
55
"""ModelGarden API Compatible Plugin for Genkit."""
66

77
import os
8-
import pdb
9-
from pprint import pprint
108

11-
from pydantic import BaseModel, ConfigDict
12-
13-
from genkit.ai.plugin import Plugin
14-
from genkit.ai.registry import GenkitRegistry
9+
from genkit.ai import GenkitRegistry, Plugin
1510
from genkit.plugins.vertex_ai import constants as const
1611

17-
from .model_garden import SUPPORTED_OPENAI_FORMAT_MODELS, ModelGarden
18-
19-
20-
class CommonPluginOptions(BaseModel):
21-
model_config = ConfigDict(extra='forbid', populate_by_name=True)
22-
23-
project_id: str | None = None
24-
location: str | None = None
25-
models: list[str] | None = None
26-
27-
28-
def vertexai_name(name: str) -> str:
29-
"""Create a Vertex AI action name.
30-
31-
Args:
32-
name: Base name for the action.
33-
34-
Returns:
35-
The fully qualified Vertex AI action name.
36-
"""
37-
return f'vertexai/{name}'
12+
from .model_garden import OPENAI_COMPAT, ModelGarden
3813

3914

4015
class VertexAIModelGarden(Plugin):
@@ -46,24 +21,25 @@ class VertexAIModelGarden(Plugin):
4621
registration of model actions.
4722
"""
4823

49-
name = 'modelgarden'
24+
name = 'vertex-ai-model-garden'
5025

51-
def __init__(self, **kwargs):
26+
def __init__(self, project_id: str | None = None, location: str | None = None, models: list[str] | None = None):
5227
"""Initialize the plugin by registering actions with the registry."""
53-
self.plugin_options = CommonPluginOptions(
54-
project_id=kwargs.get('project_id', os.getenv(const.GCLOUD_PROJECT)),
55-
location=kwargs.get('location', const.DEFAULT_REGION),
56-
models=kwargs.get('models'),
57-
)
28+
self.project_id = project_id if project_id is not None else os.getenv(const.GCLOUD_PROJECT)
29+
self.location = location if location is not None else const.DEFAULT_REGION
30+
self.models = models
5831

5932
def initialize(self, ai: GenkitRegistry) -> None:
6033
"""Handles actions for various openaicompatible models."""
61-
for model in self.plugin_options.models:
62-
openai_model = SUPPORTED_OPENAI_FORMAT_MODELS.get(model).label
63-
if openai_model:
64-
ModelGarden.to_openai_compatible_model(
65-
ai,
66-
model=openai_model,
67-
location=self.plugin_options.location,
68-
project_id=self.plugin_options.project_id,
69-
)
34+
models = self.models
35+
if models:
36+
for model in models:
37+
model_info = ModelGarden.get_model_info(model)
38+
if model_info:
39+
if model_info['type'] == OPENAI_COMPAT:
40+
ModelGarden.to_openai_compatible_model(
41+
ai,
42+
model=model_info['name'],
43+
location=self.location,
44+
project_id=self.project_id,
45+
)

0 commit comments

Comments
 (0)