Skip to content

Commit 89c0744

Browse files
author
Teryl Taylor
committed
feat(plugins): integrated plugins into prompt service, fixed linting and type issues.
Signed-off-by: Teryl Taylor <[email protected]>
1 parent a91eee0 commit 89c0744

File tree

13 files changed

+137
-37
lines changed

13 files changed

+137
-37
lines changed

.env.example

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,8 @@ DEBUG=false
285285
# Gateway tool name separator
286286
GATEWAY_TOOL_NAME_SEPARATOR=-
287287
VALID_SLUG_SEPARATOR_REGEXP= r"^(-{1,2}|[_.])$"
288+
289+
#####################################
290+
# Plugins Settings
291+
#####################################
292+
PLUGINS_ENABLED=false

MANIFEST.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ recursive-include alembic *.md
5555
recursive-include alembic *.py
5656
# recursive-include deployment *
5757
# recursive-include mcp-servers *
58+
recursive-include plugins *.py
59+
recursive-include plugins *.yaml
5860

5961
# 5️⃣ (Optional) include MKDocs-based docs in the sdist
6062
# graft docs

mcpgateway/config.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,10 @@ def _parse_federation_peers(cls, v):
309309
use_stateful_sessions: bool = False # Set to False to use stateless sessions without event store
310310
json_response_enabled: bool = True # Enable JSON responses instead of SSE streams
311311

312+
# Core plugin settings
313+
plugins_enabled: bool = Field(default=False, description="Enable the plugin framework")
314+
plugin_config_file: str = Field(default="plugins/config.yaml", description="Path to main plugin configuration file")
315+
312316
# Development
313317
dev_mode: bool = False
314318
reload: bool = False
@@ -495,9 +499,7 @@ def validate_database(self) -> None:
495499
db_dir.mkdir(parents=True)
496500

497501
# Validation patterns for safe display (configurable)
498-
validation_dangerous_html_pattern: str = (
499-
r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|</*(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)>"
500-
)
502+
validation_dangerous_html_pattern: str = r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|</*(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)>"
501503

502504
validation_dangerous_js_pattern: str = r"(?i)(?:^|\s|[\"'`<>=])(javascript:|vbscript:|data:\s*[^,]*[;\s]*(javascript|vbscript)|\bon[a-z]+\s*=|<\s*script\b)"
503505

mcpgateway/main.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
ResourceContent,
7474
Root,
7575
)
76+
from mcpgateway.plugins.framework.manager import PluginManager
7677
from mcpgateway.schemas import (
7778
GatewayCreate,
7879
GatewayRead,
@@ -158,6 +159,8 @@
158159
else:
159160
loop.create_task(bootstrap_db())
160161

162+
# Initialize plugin manager as a singleton.
163+
plugin_manager: PluginManager | None = PluginManager(settings.plugin_config_file) if settings.plugins_enabled else None
161164

162165
# Initialize services
163166
tool_service = ToolService()
@@ -212,6 +215,9 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
212215
"""
213216
logger.info("Starting MCP Gateway services")
214217
try:
218+
if plugin_manager:
219+
await plugin_manager.initialize()
220+
logger.info(f"Plugin manager initialized with {plugin_manager.plugin_count} plugins")
215221
await tool_service.initialize()
216222
await resource_service.initialize()
217223
await prompt_service.initialize()
@@ -230,6 +236,13 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
230236
logger.error(f"Error during startup: {str(e)}")
231237
raise
232238
finally:
239+
# Shutdown plugin manager
240+
if plugin_manager:
241+
try:
242+
await plugin_manager.shutdown()
243+
logger.info("Plugin manager shutdown complete")
244+
except Exception as e:
245+
logger.error(f"Error shutting down plugin manager: {str(e)}")
233246
logger.info("Shutting down MCP Gateway services")
234247
# await stop_streamablehttp()
235248
for service in [resource_cache, sampling_handler, logging_service, completion_service, root_service, gateway_service, prompt_service, resource_service, tool_service, streamable_http_session]:

mcpgateway/plugins/framework/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: Plugi
123123
of plugin type {type(self)}
124124
""")
125125

126-
def shutdown(self) -> None:
126+
async def shutdown(self) -> None:
127127
"""Plugin cleanup code."""
128128

129129

mcpgateway/plugins/framework/loader/plugin.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,32 @@ def __init__(self) -> None:
2828
self._plugin_types: dict[str, Type[Plugin]] = {}
2929

3030
def __get_plugin_type(self, kind: str) -> Type[Plugin]:
31+
"""Import a plugin type from a python module.
32+
33+
Args:
34+
kind: The fully-qualified type of the plugin to be registered.
35+
36+
Raises:
37+
Exception: if unable to import a module.
38+
39+
Returns:
40+
A plugin type.
41+
"""
3142
try:
3243
(mod_name, cls_name) = parse_class_name(kind)
3344
module = import_module(mod_name)
3445
class_ = getattr(module, cls_name)
3546
return cast(Type[Plugin], class_)
3647
except Exception:
37-
logger.exception("Unable to instantiate class '%s'", kind)
48+
logger.exception("Unable to import plugin type '%s'", kind)
3849
raise
3950

4051
def __register_plugin_type(self, kind: str) -> None:
52+
"""Register a plugin type.
53+
54+
Args:
55+
kind: The fully-qualified type of the plugin to be registered.
56+
"""
4157
if kind not in self._plugin_types:
4258
plugin_type = self.__get_plugin_type(kind)
4359
self._plugin_types[kind] = plugin_type

mcpgateway/plugins/framework/manager.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@
3232

3333
logger = logging.getLogger(__name__)
3434

35-
T = TypeVar('T')
35+
T = TypeVar("T")
3636

3737

3838
class PluginExecutor(Generic[T]):
3939
"""Executes a list of plugins."""
40+
4041
async def execute(
4142
self,
4243
plugins: list[PluginRef],
@@ -45,15 +46,15 @@ async def execute(
4546
plugin_run: Callable[[PluginRef, T, PluginContext], Coroutine[Any, Any, PluginResult[T]]],
4647
compare: Callable[[T, list[PluginCondition], GlobalContext], bool],
4748
local_contexts: Optional[PluginContextTable] = None,
48-
) -> tuple[PluginResult[T] | None, PluginContextTable | None]:
49+
) -> tuple[PluginResult[T], PluginContextTable | None]:
4950
"""Execute a plugins hook run before a prompt is retrieved and rendered.
5051
5152
Args:
5253
plugins: the list of plugins to execute.
5354
payload: the payload to be analyzed.
5455
global_context: contextual information for all plugins.
5556
plugin_run: async function for executing plugin hook.
56-
compare: function for comparing conditional information with context and payload
57+
compare: function for comparing conditional information with context and payload.
5758
local_contexts: context local to a single plugin.
5859
5960
Returns:
@@ -85,11 +86,11 @@ async def execute(
8586
if not result.continue_processing:
8687
# Check execution mode
8788
if pluginref.plugin.mode == PluginMode.ENFORCE:
88-
return (PluginResult[T](continue_processing=False, modified_payload=current_payload, error=result.error, metadata=combined_metadata), None)
89+
return (PluginResult[T](continue_processing=False, modified_payload=current_payload, violation=result.violation, metadata=combined_metadata), None)
8990
elif pluginref.plugin.mode == PluginMode.PERMISSIVE:
90-
logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.error}")
91+
logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.violation.description if result.violation else ''}")
9192

92-
return (PluginResult[T](continue_processing=True, modified_payload=current_payload, error=None, metadata=combined_metadata), res_local_contexts)
93+
return (PluginResult[T](continue_processing=True, modified_payload=current_payload, violation=None, metadata=combined_metadata), res_local_contexts)
9394

9495

9596
async def pre_prompt_fetch(plugin: PluginRef, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult:
@@ -176,7 +177,7 @@ async def initialize(self) -> None:
176177
"""
177178
if self._initialized:
178179
return
179-
180+
180181
plugins = self._config.plugins if self._config else []
181182

182183
for plugin_config in plugins:
@@ -204,7 +205,7 @@ async def prompt_pre_fetch(
204205
payload: PromptPrehookPayload,
205206
global_context: GlobalContext,
206207
local_contexts: Optional[PluginContextTable] = None,
207-
) -> tuple[PromptPrehookResult | None, PluginContextTable | None]:
208+
) -> tuple[PromptPrehookResult, PluginContextTable | None]:
208209
"""Plugin hook run before a prompt is retrieved and rendered.
209210
210211
Args:
@@ -218,11 +219,9 @@ async def prompt_pre_fetch(
218219
plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH)
219220
return await self._pre_prompt_executor.execute(plugins, payload, global_context, pre_prompt_fetch, pre_prompt_matches, local_contexts)
220221

221-
222-
223222
async def prompt_post_fetch(
224223
self, payload: PromptPosthookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None
225-
) -> tuple[PromptPosthookResult | None, PluginContextTable | None]:
224+
) -> tuple[PromptPosthookResult, PluginContextTable | None]:
226225
"""Plugin hook run after a prompt is rendered.
227226
228227
Args:

mcpgateway/plugins/framework/models.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,19 +157,19 @@ class PluginManifest(BaseModel):
157157
default_config: dict[str, Any]
158158

159159

160-
class PluginError(BaseModel): # (ErrorResponse): # Inherits from MCP error format
161-
"""A plugin error.
160+
class PluginViolation(BaseModel):
161+
"""A plugin filter violation.
162162
163163
Attributes:
164164
plugin_name (str): The name of the plugin.
165-
error_description (str): the error in text.
166-
error_code (str): an error code.
165+
description (str): the violation in text.
166+
violation_code (str): a violation code.
167167
details: (dict[str, Any])
168168
"""
169169

170170
plugin_name: str
171-
error_description: str
172-
error_code: str
171+
description: str
172+
violation_code: str
173173
details: dict[str, Any]
174174

175175

mcpgateway/plugins/framework/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def get_all_plugins(self) -> list[PluginRef]:
106106
"""
107107
return list(self._plugins.values())
108108

109+
@property
109110
def plugin_count(self) -> int:
110111
"""Return the number of plugins registered.
111112

mcpgateway/plugins/framework/types.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
# First-Party
1616
from mcpgateway.models import PromptResult
17-
from mcpgateway.plugins.framework.models import PluginError
17+
from mcpgateway.plugins.framework.models import PluginViolation
1818

1919
T = TypeVar("T")
2020

@@ -50,18 +50,18 @@ def __init__(self, name: str, result: PromptResult):
5050
class PluginResult(Generic[T]):
5151
"""A plugin result."""
5252

53-
def __init__(self, continue_processing: bool = True, modified_payload: Optional[T] = None, error: Optional[PluginError] = None, metadata: Optional[dict[str, Any]] = None):
53+
def __init__(self, continue_processing: bool = True, modified_payload: Optional[T] = None, violation: Optional[PluginViolation] = None, metadata: Optional[dict[str, Any]] = None):
5454
"""Initialize a plugin result object.
5555
5656
Args:
5757
continue_processing (bool): Whether to stop processing.
5858
modified_payload (Optional[Any]): The modified payload if the plugin is a transformer.
59-
error (Optional[PluginError]): error object.
59+
violation (Optional[PluginViolation]): violation object.
6060
metadata (Optional[dict[str, Any]]): additional metadata.
6161
"""
6262
self.continue_processing = continue_processing
6363
self.modified_payload = modified_payload
64-
self.error = error
64+
self.violation = violation
6565
self.metadata = metadata or {}
6666

6767

@@ -72,13 +72,7 @@ def __init__(self, continue_processing: bool = True, modified_payload: Optional[
7272
class GlobalContext:
7373
"""The global context, which shared across all plugins."""
7474

75-
def __init__(
76-
self,
77-
request_id: str,
78-
user: Optional[str] = None,
79-
tenant_id: Optional[str] = None,
80-
server_id: Optional[str] = None
81-
) -> None:
75+
def __init__(self, request_id: str, user: Optional[str] = None, tenant_id: Optional[str] = None, server_id: Optional[str] = None) -> None:
8276
"""Initialize a global context.
8377
8478
Args:
@@ -92,6 +86,7 @@ def __init__(
9286
self.tenant_id = tenant_id
9387
self.server_id = server_id
9488

89+
9590
class PluginContext(GlobalContext):
9691
"""The plugin's context, which lasts a request lifecycle.
9792

0 commit comments

Comments
 (0)