feat(py): Added ModelGarden plugin #2568

Open · wants to merge 1 commit into base: main
@@ -19,7 +19,7 @@

from .handler import OpenAIModelHandler
from .model import OpenAIModel
from .model_info import SUPPORTED_OPENAI_MODELS
from .model_info import SUPPORTED_OPENAI_COMPAT_MODELS, SUPPORTED_OPENAI_MODELS, PluginSource


def package_name() -> str:
@@ -28,7 +28,9 @@ def package_name() -> str:

__all__ = [
'OpenAIModel',
'PluginSource',
'SUPPORTED_OPENAI_MODELS',
'SUPPORTED_OPENAI_COMPAT_MODELS',
'OpenAIModelHandler',
'package_name',
]
@@ -14,7 +14,6 @@
#
# SPDX-License-Identifier: Apache-2.0


"""OpenAI Compatible Model handlers for Genkit."""

from collections.abc import Callable
@@ -25,7 +24,9 @@
from genkit.ai import ActionRunContext, GenkitRegistry
from genkit.plugins.compat_oai.models.model import OpenAIModel
from genkit.plugins.compat_oai.models.model_info import (
SUPPORTED_OPENAI_COMPAT_MODELS,
SUPPORTED_OPENAI_MODELS,
PluginSource,
)
from genkit.plugins.compat_oai.typing import OpenAIConfig
from genkit.types import (
@@ -37,17 +38,34 @@
class OpenAIModelHandler:
"""Handles OpenAI API interactions for the Genkit plugin."""

def __init__(self, model: Any) -> None:
def __init__(self, model: Any, source: PluginSource = PluginSource.OPENAI) -> None:
"""Initializes the OpenAIModelHandler with a specified model.

Args:
model: An instance of a Model subclass representing the OpenAI model.
source: Identifies whether the handler is invoked from the OpenAI plugin or the Model Garden plugin.
Default source is openai.
"""
self._model = model
self._source = source

@staticmethod
def _get_supported_models(source: PluginSource) -> dict[str, Any]:
"""Returns the supported models based on the plugin source.
Args:
source: Helps distinguish if model handler is called from model-garden plugin.
Default source is openai.

Returns:
Openai models if source is openai. Merges supported openai models with openai-compat models if source is model-garden.

"""

return SUPPORTED_OPENAI_COMPAT_MODELS if source == PluginSource.MODEL_GARDEN else SUPPORTED_OPENAI_MODELS

@classmethod
def get_model_handler(
cls, model: str, client: OpenAI, registry: GenkitRegistry
cls, model: str, client: OpenAI, registry: GenkitRegistry, source: PluginSource = PluginSource.OPENAI
) -> Callable[[GenerateRequest, ActionRunContext], GenerateResponse]:
"""Factory method to initialize the model handler for the specified OpenAI model.

@@ -61,18 +79,22 @@ def get_model_handler(
model: The OpenAI model name.
client: OpenAI client instance.
registry: Genkit registry instance.
source: Identifies whether the handler is invoked from the OpenAI plugin or the Model Garden plugin.
Default source is openai.

Returns:
A callable function that acts as an action handler.

Raises:
ValueError: If the specified model is not supported.
"""
if model not in SUPPORTED_OPENAI_MODELS:
supported_models = cls._get_supported_models(source)

if model not in supported_models:
raise ValueError(f"Model '{model}' is not supported.")

openai_model = OpenAIModel(model, client, registry)
return cls(openai_model).generate
return cls(openai_model, source).generate

def _validate_version(self, version: str) -> None:
"""Validates whether the specified model version is supported.
@@ -83,7 +105,8 @@ def _validate_version(self, version: str) -> None:
Raises:
ValueError: If the specified model version is not supported.
"""
model_info = SUPPORTED_OPENAI_MODELS[self._model.name]
supported_models = self._get_supported_models(self._source)
model_info = supported_models[self._model.name]
if version not in model_info.versions:
raise ValueError(f"Model version '{version}' is not supported.")

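A minimal sketch (not part of this diff) of how the source-aware lookup behaves; it calls the private _get_supported_models helper purely for illustration, and the model names come from the tables in model_info.py:

from genkit.plugins.compat_oai.models import OpenAIModelHandler, PluginSource

# The default (openai) source consults only the OpenAI model table.
openai_models = OpenAIModelHandler._get_supported_models(PluginSource.OPENAI)
assert 'gpt-3.5-turbo' in openai_models
assert 'llama-3.1' not in openai_models

# The model-garden source consults the OpenAI-compatible table instead
# (the two tables are not merged).
compat_models = OpenAIModelHandler._get_supported_models(PluginSource.MODEL_GARDEN)
assert 'llama-3.1' in compat_models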
@@ -17,19 +17,37 @@

"""OpenAI Compatible Models for Genkit."""

import sys

if sys.version_info < (3, 11):
from strenum import StrEnum
else:
from enum import StrEnum

from genkit.plugins.compat_oai.typing import SupportedOutputFormat
from genkit.types import (
ModelInfo,
Supports,
)

OPENAI = 'openai'
MODEL_GARDEN = 'model-garden'


class PluginSource(StrEnum):
"""Identifies which plugin is invoking the shared OpenAI-compatible handler."""

OPENAI = 'openai'
MODEL_GARDEN = 'model-garden'


GPT_3_5_TURBO = 'gpt-3.5-turbo'
GPT_4 = 'gpt-4'
GPT_4_TURBO = 'gpt-4-turbo'
GPT_4O = 'gpt-4o'
GPT_4O_MINI = 'gpt-4o-mini'
O1_MINI = 'o1-mini'

LLAMA_3_1 = 'llama-3.1'
LLAMA_3_2 = 'llama-3.2'

SUPPORTED_OPENAI_MODELS: dict[str, ModelInfo] = {
GPT_3_5_TURBO: ModelInfo(
@@ -124,3 +142,28 @@
),
),
}

SUPPORTED_OPENAI_COMPAT_MODELS: dict[str, ModelInfo] = {
LLAMA_3_1: ModelInfo(
versions=['meta/llama3-405b-instruct-maas'],
label='llama-3.1',
supports=Supports(
multiturn=True,
media=False,
tools=True,
systemRole=True,
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
),
),
LLAMA_3_2: ModelInfo(
versions=['meta/llama-3.2-90b-vision-instruct-maas'],
label='llama-3.2',
supports=Supports(
multiturn=True,
media=True,
tools=True,
systemRole=True,
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
),
),
}
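A quick sketch (not part of this diff) of reading the new OpenAI-compatible table; the ModelInfo fields are the same ones populated above:

from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_COMPAT_MODELS

info = SUPPORTED_OPENAI_COMPAT_MODELS['llama-3.1']
print(info.label)               # 'llama-3.1'
print(info.versions)            # ['meta/llama3-405b-instruct-maas']
print(info.supports.multiturn)  # True
print(info.supports.media)      # False (llama-3.2 sets media=True)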
@@ -0,0 +1,20 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from .model_garden import model_garden_name
from .modelgarden_plugin import VertexAIModelGarden

__all__ = ['model_garden_name', 'VertexAIModelGarden']
@@ -0,0 +1,35 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from google.auth import default
from google.auth.transport.requests import Request
from openai import OpenAI as _OpenAI


class OpenAIClient:
"""Handles OpenAI API client Initialization."""

def __new__(cls, **openai_params) -> _OpenAI:
"""Initializes the OpenAIClient based on the plugin source."""
location = openai_params.get('location')
project_id = openai_params.get('project_id')
if project_id:
credentials, _ = default()
else:
credentials, project_id = default()

credentials.refresh(Request())
base_url = f'https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi'
return _OpenAI(api_key=credentials.token, base_url=base_url)
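A usage sketch (not part of this diff), assuming Application Default Credentials are configured for the project; the import path, project, and location values are placeholders, and the chat call goes through the standard OpenAI SDK surface of the returned client:

from client import OpenAIClient  # hypothetical path; the class lives in the plugin's client module

client = OpenAIClient(location='us-central1', project_id='my-gcp-project')  # placeholder values
response = client.chat.completions.create(
    model='meta/llama3-405b-instruct-maas',
    messages=[{'role': 'user', 'content': 'Hello!'}],
)
print(response.choices[0].message.content)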
@@ -0,0 +1,72 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from genkit.ai import GenkitRegistry
from genkit.plugins.compat_oai.models import (
SUPPORTED_OPENAI_COMPAT_MODELS,
OpenAIModelHandler,
PluginSource,
)
from genkit.plugins.compat_oai.typing import OpenAIConfig

from .client import OpenAIClient

OPENAI_COMPAT = 'openai-compat'


def model_garden_name(name: str) -> str:
"""Create a Model Garden action name.

Args:
name: Base name for the action.

Returns:
The fully qualified Model Garden action name.
"""
return f'modelgarden/{name}'


class ModelGarden:
"""Registers Vertex AI Model Garden models through the OpenAI-compatible handler."""

@staticmethod
def get_model_info(name: str) -> dict[str, str] | None:
"""Returns the label and handler type for a given model.

Args:
name: Name of the model to look up.

Returns:
A dict with the model's label and handler type, or None if the model is not supported.
"""
model_info = SUPPORTED_OPENAI_COMPAT_MODELS.get(name)
if model_info:
return {'name': model_info.label, 'type': OPENAI_COMPAT}

@classmethod
def to_openai_compatible_model(cls, ai: GenkitRegistry, model: str, location: str, project_id: str):
"""Registers an OpenAI-compatible Model Garden model with the Genkit registry.

Raises:
ValueError: If the specified model is not supported.
"""
if model not in SUPPORTED_OPENAI_COMPAT_MODELS:
raise ValueError(f"Model '{model}' is not supported.")
openai_params = {'location': location, 'project_id': project_id}
openai_client = OpenAIClient(**openai_params)
handler = OpenAIModelHandler.get_model_handler(
model=model, client=openai_client, registry=ai, source=PluginSource.MODEL_GARDEN
)

supports = SUPPORTED_OPENAI_COMPAT_MODELS[model].supports.model_dump()

ai.define_model(
name=model_garden_name(model),
fn=handler,
config_schema=OpenAIConfig,
metadata={'model': {'supports': supports}},
)
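For illustration (not part of this diff), the lookup helper above resolves supported models to their label and handler type:

ModelGarden.get_model_info('llama-3.1')
# -> {'name': 'llama-3.1', 'type': 'openai-compat'}

ModelGarden.get_model_info('gpt-4o')
# -> None; only models listed in SUPPORTED_OPENAI_COMPAT_MODELS are handled here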
@@ -0,0 +1,57 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""ModelGarden API Compatible Plugin for Genkit."""

import os

from genkit.ai import GenkitRegistry, Plugin
from genkit.plugins.vertex_ai import constants as const

from .model_garden import OPENAI_COMPAT, ModelGarden


class VertexAIModelGarden(Plugin):
"""Model Garden plugin for Genkit.

This plugin provides integration with Google Cloud's Vertex AI platform,
enabling the use of Vertex AI models and services within the Genkit
framework. It handles initialization of the Model Garden client and
registration of model actions.
"""

name = 'modelgarden'

def __init__(self, project_id: str | None = None, location: str | None = None, models: list[str] | None = None):
"""Initialize the plugin by registering actions with the registry."""
self.project_id = project_id if project_id is not None else os.getenv(const.GCLOUD_PROJECT)
self.location = location if location is not None else const.DEFAULT_REGION
self.models = models

def initialize(self, ai: GenkitRegistry) -> None:
"""Handles actions for various openaicompatible models."""
models = self.models
if models:
for model in models:
model_info = ModelGarden.get_model_info(model)
if model_info:
if model_info['type'] == OPENAI_COMPAT:
ModelGarden.to_openai_compatible_model(
ai,
model=model_info['name'],
location=self.location,
project_id=self.project_id,
)
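A minimal end-to-end sketch (not part of this diff) of wiring the plugin into a Genkit app. It assumes genkit.ai exposes the Genkit entry point with an async generate method and a .text accessor on the response; the VertexAIModelGarden import path and the project and location values are placeholders:

import asyncio

from genkit.ai import Genkit  # assumed entry point
from modelgarden_plugin import VertexAIModelGarden  # hypothetical path; actual package path not shown in this diff

ai = Genkit(
    plugins=[
        VertexAIModelGarden(
            project_id='my-gcp-project',  # placeholder
            location='us-central1',       # placeholder
            models=['llama-3.1'],
        ),
    ],
)


async def main() -> None:
    # Models registered by the plugin are addressed as 'modelgarden/<name>'.
    response = await ai.generate(
        model='modelgarden/llama-3.1',
        prompt='Say hello in one sentence.',
    )
    print(response.text)


asyncio.run(main())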