14
14
#
15
15
# SPDX-License-Identifier: Apache-2.0
16
16
17
-
18
17
"""OpenAI Compatible Model handlers for Genkit."""
19
18
20
19
from collections .abc import Callable
25
24
from genkit .ai import ActionRunContext , GenkitRegistry
26
25
from genkit .plugins .compat_oai .models .model import OpenAIModel
27
26
from genkit .plugins .compat_oai .models .model_info import (
27
+ SUPPORTED_OPENAI_COMPAT_MODELS ,
28
28
SUPPORTED_OPENAI_MODELS ,
29
+ PluginSource ,
29
30
)
30
31
from genkit .plugins .compat_oai .typing import OpenAIConfig
31
32
from genkit .types import (
37
38
class OpenAIModelHandler :
38
39
"""Handles OpenAI API interactions for the Genkit plugin."""
39
40
40
- def __init__ (self , model : Any ) -> None :
41
+ def __init__ (self , model : Any , source : PluginSource = PluginSource . OPENAI ) -> None :
41
42
"""Initializes the OpenAIModelHandler with a specified model.
42
43
43
44
Args:
44
45
model: An instance of a Model subclass representing the OpenAI model.
46
+ source: Helps distinguish if model handler is called from model-garden plugin.
47
+ Default source is openai.
45
48
"""
46
49
self ._model = model
50
+ self ._source = source
51
+
52
+ @staticmethod
53
+ def _get_supported_models (source : PluginSource ) -> dict [str , Any ]:
54
+ """Returns the supported models based on the plugin source.
55
+ Args:
56
+ source: Helps distinguish if model handler is called from model-garden plugin.
57
+ Default source is openai.
58
+
59
+ Returns:
60
+ Openai models if source is openai. Merges supported openai models with openai-compat models if source is model-garden.
61
+
62
+ """
63
+
64
+ return SUPPORTED_OPENAI_COMPAT_MODELS if source == PluginSource .MODEL_GARDEN else SUPPORTED_OPENAI_MODELS
47
65
48
66
@classmethod
49
67
def get_model_handler (
50
- cls , model : str , client : OpenAI , registry : GenkitRegistry
68
+ cls , model : str , client : OpenAI , registry : GenkitRegistry , source : PluginSource = PluginSource . OPENAI
51
69
) -> Callable [[GenerateRequest , ActionRunContext ], GenerateResponse ]:
52
70
"""Factory method to initialize the model handler for the specified OpenAI model.
53
71
@@ -61,18 +79,22 @@ def get_model_handler(
61
79
model: The OpenAI model name.
62
80
client: OpenAI client instance.
63
81
registry: Genkit registry instance.
82
+ source: Helps distinguish if model handler is called from model-garden plugin.
83
+ Default source is openai.
64
84
65
85
Returns:
66
86
A callable function that acts as an action handler.
67
87
68
88
Raises:
69
89
ValueError: If the specified model is not supported.
70
90
"""
71
- if model not in SUPPORTED_OPENAI_MODELS :
91
+ supported_models = cls ._get_supported_models (source )
92
+
93
+ if model not in supported_models :
72
94
raise ValueError (f"Model '{ model } ' is not supported." )
73
95
74
96
openai_model = OpenAIModel (model , client , registry )
75
- return cls (openai_model ).generate
97
+ return cls (openai_model , source ).generate
76
98
77
99
def _validate_version (self , version : str ) -> None :
78
100
"""Validates whether the specified model version is supported.
@@ -83,7 +105,8 @@ def _validate_version(self, version: str) -> None:
83
105
Raises:
84
106
ValueError: If the specified model version is not supported.
85
107
"""
86
- model_info = SUPPORTED_OPENAI_MODELS [self ._model .name ]
108
+ supported_models = self ._get_supported_models (self ._source )
109
+ model_info = supported_models [self ._model .name ]
87
110
if version not in model_info .versions :
88
111
raise ValueError (f"Model version '{ version } ' is not supported." )
89
112
0 commit comments