[Draft]: Vertex AI model garden with OpenAI Compatibility #2336

Closed · wants to merge 1 commit
@@ -0,0 +1,8 @@
from .model_garden import OpenAIFormatModelVersion, vertexai_name
from .modelgarden_plugin import VertexAIModelGarden

__all__ = [
    OpenAIFormatModelVersion.__name__,
    vertexai_name.__name__,
    VertexAIModelGarden.__name__,
]
@@ -0,0 +1,95 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from enum import StrEnum

from genkit.core.typing import (
    ModelInfo,
    Supports,
)
from genkit.veneer.registry import GenkitRegistry
from .openai_compatibility import OpenAICompatibleModel, OpenAIConfig


def vertexai_name(name: str) -> str:
    """Create a Vertex AI action name.

    Args:
        name: Base name for the action.

    Returns:
        The fully qualified Vertex AI action name.
    """
    return f'vertexai/{name}'


class OpenAIFormatModelVersion(StrEnum):
    """Available versions of the llama model.

    This enum defines the available versions of the llama model that
    can be used through Vertex AI.
    """

    LLAMA_3_1 = 'llama-3.1'
    LLAMA_3_2 = 'llama-3.2'


SUPPORTED_OPENAI_FORMAT_MODELS: dict[str, ModelInfo] = {
    OpenAIFormatModelVersion.LLAMA_3_1: ModelInfo(
        versions=['meta/llama3-405b-instruct-maas'],
        label='llama-3.1',
        supports=Supports(
            multiturn=True,
            media=False,
            tools=True,
            systemRole=True,
            output=['text', 'json'],
        ),
    ),
    OpenAIFormatModelVersion.LLAMA_3_2: ModelInfo(
        versions=['meta/llama-3.2-90b-vision-instruct-maas'],
        label='llama-3.2',
        supports=Supports(
            multiturn=True,
            media=True,
            tools=True,
            systemRole=True,
            output=['text', 'json'],
        ),
    ),
}


class ModelGarden:
    """Registers Model Garden models that expose an OpenAI-compatible API."""

    @classmethod
    def to_openai_compatible_model(
        cls,
        ai: GenkitRegistry,
        model: str,
        location: str,
        project_id: str,
    ):
        """Defines a Genkit model action for a supported Model Garden model."""
        if model not in SUPPORTED_OPENAI_FORMAT_MODELS:
            raise ValueError(f"Model '{model}' is not supported.")
        model_version = SUPPORTED_OPENAI_FORMAT_MODELS[model].versions[0]
        open_ai_compat = OpenAICompatibleModel(
            model_version,
            project_id,
            location,
        )
        supports = SUPPORTED_OPENAI_FORMAT_MODELS[model].supports.model_dump()

        ai.define_model(
            name=f'vertexai/{model}',
            fn=open_ai_compat.generate,
            config_schema=OpenAIConfig,
            metadata={
                'model': {
                    'supports': supports,
                }
            },
        )
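
For reviewers, a rough sketch of how this helper and vertexai_name would be exercised directly; the registry handle, project ID, and region below are hypothetical placeholders rather than anything defined in this PR:

# Assumes `ai` is an existing GenkitRegistry instance; the project and
# region values are placeholders for illustration only.
ModelGarden.to_openai_compatible_model(
    ai,
    model='llama-3.1',
    location='us-central1',
    project_id='my-gcp-project',
)

# vertexai_name() only prefixes the action name:
assert vertexai_name('llama-3.1') == 'vertexai/llama-3.1'

The call registers the model action as 'vertexai/llama-3.1', with the supports metadata pulled from SUPPORTED_OPENAI_FORMAT_MODELS.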


@@ -0,0 +1,75 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0


"""ModelGarden API Compatible Plugin for Genkit."""

import os

from pydantic import BaseModel, ConfigDict

from genkit.plugins.vertex_ai import constants as const
from genkit.veneer.plugin import Plugin
from genkit.veneer.registry import GenkitRegistry

from .model_garden import (
    SUPPORTED_OPENAI_FORMAT_MODELS,
    ModelGarden,
)


class CommonPluginOptions(BaseModel):
    """Common configuration options for the Model Garden plugin."""

    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    project_id: str | None = None
    location: str | None = None
    models: list[str] | None = None


def vertexai_name(name: str) -> str:
    """Create a Vertex AI action name.

    Args:
        name: Base name for the action.

    Returns:
        The fully qualified Vertex AI action name.
    """
    return f'vertexai/{name}'

class VertexAIModelGarden(Plugin):
    """Model Garden plugin for Genkit.

    This plugin provides integration with Google Cloud's Vertex AI platform,
    enabling the use of Vertex AI models and services within the Genkit
    framework. It handles initialization of the Model Garden client and
    registration of model actions.
    """

    name = 'modelgarden'

    def __init__(self, **kwargs):
        """Initializes the plugin with the provided configuration options."""
        self.plugin_options = CommonPluginOptions(
            project_id=kwargs.get('project_id', os.getenv(const.GCLOUD_PROJECT)),
            location=kwargs.get('location', const.DEFAULT_REGION),
            models=kwargs.get('models'),
        )

    def initialize(self, ai: GenkitRegistry) -> None:
        """Registers actions for the configured OpenAI-compatible models."""
        for model in self.plugin_options.models:
            openai_model = next(
                (
                    key
                    for key, _ in SUPPORTED_OPENAI_FORMAT_MODELS.items()
                    if key == model
                ),
                None,
            )
            if openai_model:
                ModelGarden.to_openai_compatible_model(
                    ai,
                    model=openai_model,
                    location=self.plugin_options.location,
                    project_id=self.plugin_options.project_id,
                )

Review thread on the model lookup above:

Contributor:
Can you just use SUPPORTED_OPENAI_FORMAT_MODELS.get(model)?

Contributor Author:
You are right, modified it.
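
A minimal end-to-end sketch of how this plugin is meant to be wired up, assuming the Genkit veneer exposes a Genkit entry point that accepts plugins; the import paths, project ID, and region here are illustrative assumptions rather than part of this diff:

# Assumed entry point and package path; adjust to wherever the plugin lands.
from genkit.veneer import Genkit
from genkit.plugins.vertex_ai.model_garden import (
    VertexAIModelGarden,
    vertexai_name,
)

ai = Genkit(
    plugins=[
        VertexAIModelGarden(
            project_id='my-gcp-project',  # placeholder
            location='us-central1',       # placeholder
            models=['llama-3.1'],
        )
    ],
    model=vertexai_name('llama-3.1'),
)

Because initialize() only registers names found in SUPPORTED_OPENAI_FORMAT_MODELS, a model listed in the plugin options without an entry there is silently skipped.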
@@ -0,0 +1,3 @@
from .openai_compatibility import OpenAICompatibleModel

__all__ = [OpenAICompatibleModel.__name__]
@@ -0,0 +1,129 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from enum import StrEnum
from typing import Annotated

from google.auth import default, transport
from openai import OpenAI as OpenAIClient
from pydantic import BaseModel, ConfigDict

from genkit.core.action import ActionRunContext
from genkit.core.typing import (
    GenerateRequest,
    GenerateResponse,
    GenerateResponseChunk,
    GenerationCommonConfig,
    Message,
    Role,
    TextPart,
)

class ChatMessage(BaseModel):
    """A single message in an OpenAI-style chat request."""

    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    role: str
    content: str


class OpenAIConfig(GenerationCommonConfig):
    """Config for OpenAI model."""

    frequency_penalty: Annotated[float, range(-2, 2)] | None = None
    logit_bias: dict[str, Annotated[float, range(-100, 100)]] | None = None
    logprobs: bool | None = None
    presence_penalty: Annotated[float, range(-2, 2)] | None = None
    seed: int | None = None
    top_logprobs: Annotated[int, range(0, 20)] | None = None
    user: str | None = None


class ChatCompletionRole(StrEnum):
    """Available roles supported by openai-compatible models."""

    USER = 'user'
    ASSISTANT = 'assistant'
    SYSTEM = 'system'
    TOOL = 'tool'


class OpenAICompatibleModel:
    """Handles OpenAI-compatible model support in Model Garden."""

    def __init__(self, model: str, project_id: str, location: str):
        self._model = model
        self._client = self.client_factory(location, project_id)

    def client_factory(self, location: str, project_id: str) -> OpenAIClient:
        """Initializes an OpenAI-compatible client object and returns it."""
        if project_id:
            credentials, _ = default()
        else:
            credentials, project_id = default()

        credentials.refresh(transport.requests.Request())
        base_url = f'https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi'
        return OpenAIClient(api_key=credentials.token, base_url=base_url)

    def to_openai_messages(self, messages: list[Message]) -> list[ChatMessage]:
        """Converts Genkit messages to OpenAI-style chat messages."""
        if not messages:
            raise ValueError('No messages provided in the request.')
        return [
            ChatMessage(
                role=OpenAICompatibleModel.to_openai_role(m.role.value),
                content=''.join(
                    part.root.text
                    for part in m.content
                    if part.root.text is not None
                ),
            )
            for m in messages
        ]

    def generate(
        self, request: GenerateRequest, ctx: ActionRunContext
    ) -> GenerateResponse:
        """Runs a chat completion request and returns the generated response."""
        openai_config: dict = {
            'messages': self.to_openai_messages(request.messages),
            'model': self._model,
        }
        if ctx.is_streaming:
            openai_config['stream'] = True
            stream = self._client.chat.completions.create(**openai_config)
            accumulated_text: list[str] = []
            for chunk in stream:
                choice = chunk.choices[0]
                if not choice.delta.content:
                    continue

                accumulated_text.append(choice.delta.content)
                response_chunk = GenerateResponseChunk(
                    role=Role.MODEL,
                    index=choice.index,
                    content=[TextPart(text=choice.delta.content)],
                )

                ctx.send_chunk(response_chunk)

            # Return the full message assembled from the streamed chunks.
            return GenerateResponse(
                request=request,
                message=Message(
                    role=Role.MODEL,
                    content=[TextPart(text=''.join(accumulated_text))],
                ),
            )
        else:
            response = self._client.chat.completions.create(**openai_config)
            return GenerateResponse(
                request=request,
                message=Message(
                    role=Role.MODEL,
                    content=[TextPart(text=response.choices[0].message.content)],
                ),
            )

    @staticmethod
    def to_openai_role(role: Role) -> ChatCompletionRole:
        """Converts a Genkit Role to the corresponding OpenAI-compatible role."""
        match role:
            case Role.USER:
                return ChatCompletionRole.USER
            case Role.MODEL:
                return ChatCompletionRole.ASSISTANT  # 'model' maps to 'assistant'
            case Role.SYSTEM:
                return ChatCompletionRole.SYSTEM
            case Role.TOOL:
                return ChatCompletionRole.TOOL
            case _:
                raise ValueError(f"Role '{role}' doesn't map to an OpenAI role.")
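
client_factory targets Vertex AI's OpenAI-compatible endpoint and uses an Application Default Credentials access token as the API key. A standalone sketch of that same pattern, with placeholder project and region values (it mirrors the method above rather than adding new behavior):

import google.auth
import google.auth.transport.requests
from openai import OpenAI

# Placeholder values; in the plugin these come from CommonPluginOptions.
project_id = 'my-gcp-project'
location = 'us-central1'

# Resolve ADC and mint a short-lived access token to use as the API key.
credentials, _ = google.auth.default()
credentials.refresh(google.auth.transport.requests.Request())

client = OpenAI(
    api_key=credentials.token,
    base_url=(
        f'https://{location}-aiplatform.googleapis.com/v1beta1/'
        f'projects/{project_id}/locations/{location}/endpoints/openapi'
    ),
)

# The registered model action then calls client.chat.completions.create(...)
# with messages converted by to_openai_messages().

Note that the access token is short-lived; in the plugin the client is created once per model registration when OpenAICompatibleModel is constructed.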