14
14
#
15
15
# SPDX-License-Identifier: Apache-2.0
16
16
17
-
18
17
"""OpenAI Compatible Model handlers for Genkit."""
19
18
20
19
from collections .abc import Callable
25
24
from genkit .ai import ActionRunContext , GenkitRegistry
26
25
from genkit .plugins .compat_oai .models .model import OpenAIModel
27
26
from genkit .plugins .compat_oai .models .model_info import (
27
+ SUPPORTED_OPENAI_COMPAT_MODELS ,
28
28
SUPPORTED_OPENAI_MODELS ,
29
+ PluginSource ,
29
30
)
30
31
from genkit .plugins .compat_oai .typing import OpenAIConfig
31
32
from genkit .types import (
37
38
class OpenAIModelHandler :
38
39
"""Handles OpenAI API interactions for the Genkit plugin."""
39
40
40
- def __init__ (self , model : Any ) -> None :
41
+ def __init__ (self , model : Any , source : PluginSource = PluginSource . OPENAI ) -> None :
41
42
"""Initializes the OpenAIModelHandler with a specified model.
42
43
43
44
Args:
44
45
model: An instance of a Model subclass representing the OpenAI model.
46
+ source: Helps distinguish if model handler is called from model-garden plugin.
47
+ Default source is openai.
45
48
"""
46
49
self ._model = model
50
+ self ._source = source
51
+
52
+ @staticmethod
53
+ def _get_supported_models (source : PluginSource ) -> dict [str , Any ]:
54
+ """Returns the supported models based on the plugin source.
55
+ Args:
56
+ source: Helps distinguish if model handler is called from model-garden plugin.
57
+ Default source is openai.
58
+
59
+ Returns:
60
+ Openai models if source is openai. Merges supported openai models with openai-compat models if source is model-garden.
61
+
62
+ """
63
+ if source == PluginSource .MODEL_GARDEN :
64
+ return {** SUPPORTED_OPENAI_MODELS , ** SUPPORTED_OPENAI_COMPAT_MODELS }
65
+ return SUPPORTED_OPENAI_MODELS
47
66
48
67
@classmethod
49
68
def get_model_handler (
50
- cls , model : str , client : OpenAI , registry : GenkitRegistry
69
+ cls , model : str , client : OpenAI , registry : GenkitRegistry , source : PluginSource = PluginSource . OPENAI
51
70
) -> Callable [[GenerateRequest , ActionRunContext ], GenerateResponse ]:
52
71
"""Factory method to initialize the model handler for the specified OpenAI model.
53
72
@@ -61,18 +80,22 @@ def get_model_handler(
61
80
model: The OpenAI model name.
62
81
client: OpenAI client instance.
63
82
registry: Genkit registry instance.
83
+ source: Helps distinguish if model handler is called from model-garden plugin.
84
+ Default source is openai.
64
85
65
86
Returns:
66
87
A callable function that acts as an action handler.
67
88
68
89
Raises:
69
90
ValueError: If the specified model is not supported.
70
91
"""
71
- if model not in SUPPORTED_OPENAI_MODELS :
92
+ supported_models = cls ._get_supported_models (source )
93
+
94
+ if model not in supported_models :
72
95
raise ValueError (f"Model '{ model } ' is not supported." )
73
96
74
97
openai_model = OpenAIModel (model , client , registry )
75
- return cls (openai_model ).generate
98
+ return cls (openai_model , source ).generate
76
99
77
100
def _validate_version (self , version : str ) -> None :
78
101
"""Validates whether the specified model version is supported.
@@ -83,7 +106,8 @@ def _validate_version(self, version: str) -> None:
83
106
Raises:
84
107
ValueError: If the specified model version is not supported.
85
108
"""
86
- model_info = SUPPORTED_OPENAI_MODELS [self ._model .name ]
109
+ supported_models = self ._get_supported_models (self ._source )
110
+ model_info = supported_models [self ._model .name ]
87
111
if version not in model_info .versions :
88
112
raise ValueError (f"Model version '{ version } ' is not supported." )
89
113
0 commit comments