1
1
# Copyright 2025 Google LLC
2
2
# SPDX-License-Identifier: Apache-2.0
3
- import os
4
- from enum import StrEnum
5
-
6
- from genkit .ai .registry import GenkitRegistry
7
- from genkit .core .typing import (
8
- ModelInfo ,
9
- Supports ,
3
+ from genkit .ai import GenkitRegistry
4
+ from genkit .plugins .compat_oai .models import (
5
+ SUPPORTED_OPENAI_COMPAT_MODELS ,
6
+ OpenAIModelHandler ,
7
+ PluginSource ,
10
8
)
9
+ from genkit .plugins .compat_oai .openai_client_handler import PluginSource
11
10
from genkit .plugins .compat_oai .typing import OpenAIConfig
12
- from genkit .plugins .vertex_ai import constants as const
13
11
14
- from .openai_compatibility import OpenAICompatibleModel
12
+ from .client import OpenAIClient
13
+
14
+ OPENAI_COMPAT = 'openai-compat'
15
15
16
16
17
17
def vertexai_name (name : str ) -> str :
@@ -26,43 +26,33 @@ def vertexai_name(name: str) -> str:
26
26
return f'vertexai/{ name } '
27
27
28
28
29
- class OpenAIFormatModelVersion (StrEnum ):
30
- """Available versions of the llama model.
31
-
32
- This enum defines the available versions of the llama model that
33
- can be used through Vertex AI.
34
- """
35
-
36
- LLAMA_3_1 = 'llama-3.1'
37
- LLAMA_3_2 = 'llama-3.2'
38
-
29
+ class ModelGarden :
30
+ @staticmethod
31
+ def get_model_info (name : str ) -> dict [str , str ] | None :
32
+ """Returns model type and name for a given model.
39
33
40
- SUPPORTED_OPENAI_FORMAT_MODELS : dict [str , ModelInfo ] = {
41
- OpenAIFormatModelVersion .LLAMA_3_1 : ModelInfo (
42
- versions = ['meta/llama3-405b-instruct-maas' ],
43
- label = 'llama-3.1' ,
44
- supports = Supports (multiturn = True , media = False , tools = True , systemRole = True , output = ['text' , 'json' ]),
45
- ),
46
- OpenAIFormatModelVersion .LLAMA_3_2 : ModelInfo (
47
- versions = ['meta/llama-3.2-90b-vision-instruct-maas' ],
48
- label = 'llama-3.2' ,
49
- supports = Supports (multiturn = True , media = True , tools = True , systemRole = True , output = ['text' , 'json' ]),
50
- ),
51
- }
34
+ Args:
35
+ name: Name of the model for which type and name are to be returned
52
36
37
+ """
38
+ if SUPPORTED_OPENAI_COMPAT_MODELS .get (name ):
39
+ return {'name' : SUPPORTED_OPENAI_COMPAT_MODELS .get (name ).label , 'type' : OPENAI_COMPAT }
53
40
54
- class ModelGarden :
55
41
@classmethod
56
42
def to_openai_compatible_model (cls , ai : GenkitRegistry , model : str , location : str , project_id : str ):
57
- if model not in SUPPORTED_OPENAI_FORMAT_MODELS :
43
+ if model not in SUPPORTED_OPENAI_COMPAT_MODELS :
58
44
raise ValueError (f"Model '{ model } ' is not supported." )
59
- model_version = SUPPORTED_OPENAI_FORMAT_MODELS [model ].versions [0 ]
60
- open_ai_compat = OpenAICompatibleModel (model_version , project_id , location )
61
- supports = SUPPORTED_OPENAI_FORMAT_MODELS [model ].supports .model_dump ()
45
+ openai_params = {'location' : location , 'project_id' : project_id }
46
+ openai_client = OpenAIClient (** openai_params )
47
+ handler = OpenAIModelHandler .get_model_handler (
48
+ model = model , client = openai_client , registry = ai , source = PluginSource .MODEL_GARDEN
49
+ )
50
+
51
+ supports = SUPPORTED_OPENAI_COMPAT_MODELS [model ].supports .model_dump ()
62
52
63
53
ai .define_model (
64
54
name = f'vertexai/{ model } ' ,
65
- fn = open_ai_compat . generate ,
55
+ fn = handler ,
66
56
config_schema = OpenAIConfig ,
67
57
metadata = {'model' : {'supports' : supports }},
68
58
)
0 commit comments