25
25
from genkit .ai import ActionRunContext , GenkitRegistry
26
26
from genkit .plugins .compat_oai .models .model import OpenAIModel
27
27
from genkit .plugins .compat_oai .models .model_info import (
28
+ SUPPORTED_OPENAI_COMPAT_MODELS ,
28
29
SUPPORTED_OPENAI_MODELS ,
30
+ PluginSource ,
29
31
)
30
32
from genkit .plugins .compat_oai .typing import OpenAIConfig
31
33
from genkit .types import (
37
39
class OpenAIModelHandler :
38
40
"""Handles OpenAI API interactions for the Genkit plugin."""
39
41
40
- def __init__ (self , model : Any ) -> None :
42
+ def __init__ (self , model : Any , source : PluginSource = PluginSource . OPENAI ) -> None :
41
43
"""Initializes the OpenAIModelHandler with a specified model.
42
44
43
45
Args:
44
46
model: An instance of a Model subclass representing the OpenAI model.
47
+ source: Helps distinguish if model handler is called from model-garden plugin.
48
+ Default source is openai.
45
49
"""
46
50
self ._model = model
51
+ self ._source = source
52
+
53
+ @staticmethod
54
+ def _get_supported_models (source : PluginSource ) -> dict [str , Any ]:
55
+ """Returns the supported models based on the plugin source.
56
+ Args:
57
+ source: Helps distinguish if model handler is called from model-garden plugin.
58
+ Default source is openai.
59
+
60
+ Returns:
61
+ Openai models if source is openai. Merges supported openai models with openai-compat models if source is model-garden.
62
+
63
+ """
64
+ if source == PluginSource .MODEL_GARDEN :
65
+ return {** SUPPORTED_OPENAI_MODELS , ** SUPPORTED_OPENAI_COMPAT_MODELS }
66
+ return SUPPORTED_OPENAI_MODELS
47
67
48
68
@classmethod
49
69
def get_model_handler (
50
- cls , model : str , client : OpenAI , registry : GenkitRegistry
70
+ cls , model : str , client : OpenAI , registry : GenkitRegistry , source : PluginSource = PluginSource . OPENAI
51
71
) -> Callable [[GenerateRequest , ActionRunContext ], GenerateResponse ]:
52
72
"""Factory method to initialize the model handler for the specified OpenAI model.
53
73
@@ -61,18 +81,22 @@ def get_model_handler(
61
81
model: The OpenAI model name.
62
82
client: OpenAI client instance.
63
83
registry: Genkit registry instance.
84
+ source: Helps distinguish if model handler is called from model-garden plugin.
85
+ Default source is openai.
64
86
65
87
Returns:
66
88
A callable function that acts as an action handler.
67
89
68
90
Raises:
69
91
ValueError: If the specified model is not supported.
70
92
"""
71
- if model not in SUPPORTED_OPENAI_MODELS :
93
+ supported_models = cls ._get_supported_models (source )
94
+
95
+ if model not in supported_models :
72
96
raise ValueError (f"Model '{ model } ' is not supported." )
73
97
74
98
openai_model = OpenAIModel (model , client , registry )
75
- return cls (openai_model ).generate
99
+ return cls (openai_model , source ).generate
76
100
77
101
def _validate_version (self , version : str ) -> None :
78
102
"""Validates whether the specified model version is supported.
@@ -83,7 +107,8 @@ def _validate_version(self, version: str) -> None:
83
107
Raises:
84
108
ValueError: If the specified model version is not supported.
85
109
"""
86
- model_info = SUPPORTED_OPENAI_MODELS [self ._model .name ]
110
+ supported_models = self ._get_supported_models (self ._source )
111
+ model_info = supported_models [self ._model .name ]
87
112
if version not in model_info .versions :
88
113
raise ValueError (f"Model version '{ version } ' is not supported." )
89
114
0 commit comments