feat(py): Added ModelGarden plugin #2568

Merged · 12 commits · Jun 6, 2025
2 changes: 1 addition & 1 deletion captainhook.json
@@ -28,7 +28,7 @@
"run": "CaptainHook::File.BlockSecrets",
"options": {
"presets": ["Aws", "GitHub", "Stripe", "Google"],
"allowed": ["AIDAQEAAAAAAAAAAAABE"]
"allowed": ["AIDAQEAAAAAAAAA"]
}
},
{
genkit/plugins/compat_oai/models/__init__.py
@@ -19,7 +19,7 @@

from .handler import OpenAIModelHandler
from .model import OpenAIModel
from .model_info import SUPPORTED_OPENAI_MODELS
from .model_info import SUPPORTED_OPENAI_COMPAT_MODELS, SUPPORTED_OPENAI_MODELS, PluginSource, get_default_model_info


def package_name() -> str:
@@ -28,7 +28,10 @@ def package_name() -> str:

__all__ = [
'OpenAIModel',
'PluginSource',
'SUPPORTED_OPENAI_MODELS',
'SUPPORTED_OPENAI_COMPAT_MODELS',
'OpenAIModelHandler',
'get_default_model_info',
'package_name',
]
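
A quick sketch of how these expanded exports might be consumed. The package path `genkit.plugins.compat_oai.models` is inferred from the absolute imports elsewhere in this PR, not stated in this file:

```python
# Hypothetical consumer of the exports added in this PR.
from genkit.plugins.compat_oai.models import (
    SUPPORTED_OPENAI_COMPAT_MODELS,
    SUPPORTED_OPENAI_MODELS,
    PluginSource,
    get_default_model_info,
)

print(PluginSource.MODEL_GARDEN)                   # 'model-garden' (StrEnum value)
print('gpt-3.5-turbo' in SUPPORTED_OPENAI_MODELS)  # True
print('meta/llama3-405b-instruct-maas' in SUPPORTED_OPENAI_COMPAT_MODELS)  # True
print(get_default_model_info('some/model').label)  # 'ModelGarden - some/model'
```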
genkit/plugins/compat_oai/models/handler.py
@@ -14,7 +14,6 @@
#
# SPDX-License-Identifier: Apache-2.0


"""OpenAI Compatible Model handlers for Genkit."""

from collections.abc import Callable
@@ -25,9 +24,10 @@
from genkit.ai import ActionRunContext, GenkitRegistry
from genkit.plugins.compat_oai.models.model import OpenAIModel
from genkit.plugins.compat_oai.models.model_info import (
SUPPORTED_OPENAI_COMPAT_MODELS,
SUPPORTED_OPENAI_MODELS,
PluginSource,
)
from genkit.plugins.compat_oai.typing import OpenAIConfig
from genkit.types import (
GenerateRequest,
GenerateResponse,
@@ -37,17 +37,34 @@
class OpenAIModelHandler:
"""Handles OpenAI API interactions for the Genkit plugin."""

def __init__(self, model: Any) -> None:
def __init__(self, model: Any, source: PluginSource = PluginSource.OPENAI) -> None:
"""Initializes the OpenAIModelHandler with a specified model.

Args:
model: An instance of a Model subclass representing the OpenAI model.
source: The plugin that invoked the handler; defaults to PluginSource.OPENAI.
"""
self._model = model
self._source = source

@staticmethod
def _get_supported_models(source: PluginSource) -> dict[str, Any]:
"""Returns the supported models based on the plugin source.
Args:
source: Helps distinguish if model handler is called from model-garden plugin.
Default source is openai.

Returns:
Openai models if source is openai. Merges supported openai models with openai-compat models if source is model-garden.

"""

return SUPPORTED_OPENAI_COMPAT_MODELS if source == PluginSource.MODEL_GARDEN else SUPPORTED_OPENAI_MODELS

@classmethod
def get_model_handler(
cls, model: str, client: OpenAI, registry: GenkitRegistry
cls, model: str, client: OpenAI, registry: GenkitRegistry, source: PluginSource = PluginSource.OPENAI
) -> Callable[[GenerateRequest, ActionRunContext], GenerateResponse]:
"""Factory method to initialize the model handler for the specified OpenAI model.

@@ -61,18 +78,22 @@ def get_model_handler(
model: The OpenAI model name.
client: OpenAI client instance.
registry: Genkit registry instance.
source: The plugin that invoked the handler; defaults to PluginSource.OPENAI.

Returns:
A callable function that acts as an action handler.

Raises:
ValueError: If the specified model is not supported.
"""
if model not in SUPPORTED_OPENAI_MODELS:
supported_models = cls._get_supported_models(source)

if model not in supported_models:
raise ValueError(f"Model '{model}' is not supported.")

openai_model = OpenAIModel(model, client, registry)
return cls(openai_model).generate
return cls(openai_model, source).generate

def _validate_version(self, version: str) -> None:
"""Validates whether the specified model version is supported.
@@ -83,20 +104,11 @@ def _validate_version(self, version: str) -> None:
Raises:
ValueError: If the specified model version is not supported.
"""
model_info = SUPPORTED_OPENAI_MODELS[self._model.name]
supported_models = self._get_supported_models(self._source)
model_info = supported_models[self._model.name]
if version not in model_info.versions:
raise ValueError(f"Model version '{version}' is not supported.")

def _normalize_config(self, config: Any) -> OpenAIConfig:
"""Ensures the config is an OpenAIConfig instance."""
if isinstance(config, OpenAIConfig):
return config

if isinstance(config, dict):
return OpenAIConfig(**config)

raise ValueError(f'Expected request.config to be a dict or OpenAIConfig, got {type(config).__name__}.')

def generate(self, request: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse:
"""Processes the request using OpenAI's chat completion API.

@@ -110,12 +122,9 @@ def generate(self, request: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse:
Raises:
ValueError: If the specified model version is not supported.
"""
request.config = self._normalize_config(request.config)
request.config = self._model.normalize_config(request.config)

if request.config.model:
self._validate_version(request.config.model)

if ctx.is_streaming:
return self._model.generate_stream(request, ctx.send_chunk)
else:
return self._model.generate(request)
return self._model.generate(request, ctx)
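
End to end, the new `source` parameter lets a Model Garden caller reuse the OpenAI-compatible handler. A minimal sketch, assuming an OpenAI-compatible endpoint; the client arguments and the registry variable are placeholders, not part of this diff:

```python
from openai import OpenAI

from genkit.plugins.compat_oai.models.handler import OpenAIModelHandler
from genkit.plugins.compat_oai.models.model_info import LLAMA_3_1, PluginSource

# Placeholder client; a real Model Garden setup would point base_url at an
# OpenAI-compatible endpoint and supply real credentials.
client = OpenAI(base_url='https://example.invalid/v1', api_key='placeholder')
registry = ...  # a GenkitRegistry instance, obtained from the host application

# With source=MODEL_GARDEN, the model name is validated against
# SUPPORTED_OPENAI_COMPAT_MODELS instead of SUPPORTED_OPENAI_MODELS.
handler_fn = OpenAIModelHandler.get_model_handler(
    model=LLAMA_3_1,
    client=client,
    registry=registry,
    source=PluginSource.MODEL_GARDEN,
)
```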
genkit/plugins/compat_oai/models/model.py
@@ -17,14 +17,16 @@
"""OpenAI Compatible Models for Genkit."""

from collections.abc import Callable
from typing import Any

from openai import OpenAI, pydantic_function_tool
from openai.lib._pydantic import _ensure_strict_json_schema

from genkit.ai import ActionKind, GenkitRegistry
from genkit.core.action._action import ActionRunContext
from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_MODELS
from genkit.plugins.compat_oai.models.utils import DictMessageAdapter, MessageAdapter, MessageConverter
from genkit.plugins.compat_oai.typing import SupportedOutputFormat
from genkit.plugins.compat_oai.typing import OpenAIConfig, SupportedOutputFormat
from genkit.types import (
GenerateRequest,
GenerateResponse,
@@ -144,7 +146,7 @@ def _get_openai_request_config(self, request: GenerateRequest) -> dict:
openai_config.update(**request.config.model_dump(exclude_none=True))
return openai_config

def generate(self, request: GenerateRequest) -> GenerateResponse:
def _generate(self, request: GenerateRequest) -> GenerateResponse:
"""Processes the request using OpenAI's chat completion API and returns the generated response.

Args:
@@ -160,7 +162,7 @@ def generate(self, request: GenerateRequest) -> GenerateResponse:
message=MessageConverter.to_genkit(response.choices[0].message),
)

def generate_stream(self, request: GenerateRequest, callback: Callable) -> GenerateResponse:
def _generate_stream(self, request: GenerateRequest, callback: Callable) -> GenerateResponse:
"""Streams responses from the OpenAI client and sends chunks to a callback.

Args:
@@ -214,3 +216,31 @@ def generate_stream(self, request: GenerateRequest, callback: Callable) -> GenerateResponse:
request=request,
message=Message(role=Role.MODEL, content=accumulated_content),
)

def generate(self, request: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse:
"""Processes the request using OpenAI's chat completion API.

Args:
request: The request containing messages and configurations.
ctx: The context of the action run.

Returns:
A GenerateResponse containing the model's response.
"""
request.config = self.normalize_config(request.config)

if ctx.is_streaming:
return self._generate_stream(request, ctx.send_chunk)
else:
return self._generate(request)

@staticmethod
def normalize_config(config: Any) -> OpenAIConfig:
"""Ensures the config is an OpenAIConfig instance."""
if isinstance(config, OpenAIConfig):
return config

if isinstance(config, dict):
return OpenAIConfig(**config)

raise ValueError(f'Expected request.config to be a dict or OpenAIConfig, got {type(config).__name__}.')
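
Because `normalize_config` moved from the handler onto `OpenAIModel` as a static method, its coercion rules can be exercised in isolation. A small sketch; `model` is the only config field this diff references, so no other fields are assumed:

```python
from genkit.plugins.compat_oai.models.model import OpenAIModel
from genkit.plugins.compat_oai.typing import OpenAIConfig

# A dict is coerced into OpenAIConfig; an OpenAIConfig passes through as-is.
cfg = OpenAIModel.normalize_config({'model': 'gpt-4'})
assert isinstance(cfg, OpenAIConfig)
assert OpenAIModel.normalize_config(cfg) is cfg

# Anything else is rejected with the ValueError defined above.
try:
    OpenAIModel.normalize_config(None)
except ValueError as err:
    print(err)  # Expected request.config to be a dict or OpenAIConfig, got NoneType.
```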
genkit/plugins/compat_oai/models/model_info.py
@@ -17,19 +17,37 @@

"""OpenAI Compatible Models for Genkit."""

import sys

if sys.version_info < (3, 11):
from strenum import StrEnum
else:
from enum import StrEnum

from genkit.plugins.compat_oai.typing import SupportedOutputFormat
from genkit.types import (
ModelInfo,
Supports,
)

OPENAI = 'openai'
MODEL_GARDEN = 'model-garden'


class PluginSource(StrEnum):
"""Source plugin for a model handler: 'openai' or 'model-garden'."""

OPENAI = 'openai'
MODEL_GARDEN = 'model-garden'


GPT_3_5_TURBO = 'gpt-3.5-turbo'
GPT_4 = 'gpt-4'
GPT_4_TURBO = 'gpt-4-turbo'
GPT_4O = 'gpt-4o'
GPT_4O_MINI = 'gpt-4o-mini'
O1_MINI = 'o1-mini'

LLAMA_3_1 = 'meta/llama3-405b-instruct-maas'
LLAMA_3_2 = 'meta/llama-3.2-90b-vision-instruct-maas'

SUPPORTED_OPENAI_MODELS: dict[str, ModelInfo] = {
GPT_3_5_TURBO: ModelInfo(
@@ -124,3 +142,43 @@
),
),
}

SUPPORTED_OPENAI_COMPAT_MODELS: dict[str, ModelInfo] = {
LLAMA_3_1: ModelInfo(
label='ModelGarden - Meta - llama-3.1',
supports=Supports(
multiturn=True,
media=False,
tools=True,
systemRole=True,
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
),
),
LLAMA_3_2: ModelInfo(
label='ModelGarden - Meta - llama-3.2',
supports=Supports(
multiturn=True,
media=True,
tools=True,
systemRole=True,
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
),
),
}


DEFAULT_SUPPORTS = Supports(
multiturn=True,
media=True,
tools=True,
systemRole=True,
output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
)


def get_default_model_info(name: str) -> ModelInfo:
"""Gets the default model info given a name."""
return ModelInfo(
label=f'ModelGarden - {name}',
supports=DEFAULT_SUPPORTS,
)
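
The curated table plus the permissive default suggests a lookup-then-fallback pattern like the sketch below; the `resolve_model_info` helper and the Mistral model name are hypothetical, not part of this diff:

```python
from genkit.types import ModelInfo

from genkit.plugins.compat_oai.models.model_info import (
    SUPPORTED_OPENAI_COMPAT_MODELS,
    get_default_model_info,
)


def resolve_model_info(name: str) -> ModelInfo:
    """Hypothetical helper: curated entry if known, permissive default otherwise."""
    return SUPPORTED_OPENAI_COMPAT_MODELS.get(name) or get_default_model_info(name)


print(resolve_model_info('meta/llama-3.2-90b-vision-instruct-maas').label)
# 'ModelGarden - Meta - llama-3.2'
print(resolve_model_info('mistralai/mistral-large').label)
# 'ModelGarden - mistralai/mistral-large'
```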
33 changes: 0 additions & 33 deletions py/plugins/compat-oai/tests/test_handler.py
@@ -54,36 +54,3 @@ def test_validate_version() -> None:

with pytest.raises(ValueError, match="Model version 'invalid-version' is not supported."):
handler._validate_version('invalid-version')


def test_handler_generate_non_streaming(sample_request: GenerateRequest) -> None:
"""Test OpenAIModelHandler generate method in non-streaming mode."""
mock_model = MagicMock(spec=OpenAIModel)
mock_model.name = GPT_4
mock_model.generate.return_value = GenerateResponse(
message=Message(role=Role.MODEL, content=[TextPart(text='Hello, user!')])
)

handler = OpenAIModelHandler(mock_model)
mock_ctx = MagicMock(spec=ActionRunContext, is_streaming=False)

response = handler.generate(sample_request, mock_ctx)

mock_model.generate.assert_called_once()
assert isinstance(response, GenerateResponse)
assert response.message is not None
assert response.message.role == Role.MODEL
assert response.message.content[0].root.text == 'Hello, user!'


def test_handler_generate_streaming(sample_request: GenerateRequest) -> None:
"""Test OpenAIModelHandler generate method in streaming mode."""
mock_model = MagicMock(spec=OpenAIModel)
mock_model.name = GPT_4

handler = OpenAIModelHandler(mock_model)
mock_ctx = MagicMock(spec=ActionRunContext, is_streaming=True)

handler.generate(sample_request, mock_ctx)

mock_model.generate_stream.assert_called_once_with(sample_request, mock_ctx.send_chunk)