From f78e75b01f6d4474778d8a8d79b94375f5b04804 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Wed, 30 Jul 2025 11:54:08 -0600 Subject: [PATCH 01/20] feat: initial revision of adding plugin support. Signed-off-by: Teryl Taylor --- mcpgateway/plugins/framework/base.py | 212 ++++++++++++++++++ mcpgateway/plugins/framework/loader/config.py | 92 ++++++++ mcpgateway/plugins/framework/loader/plugin.py | 59 +++++ mcpgateway/plugins/framework/manager.py | 139 ++++++++++++ mcpgateway/plugins/framework/models.py | 205 +++++++++++++++++ mcpgateway/plugins/framework/registry.py | 107 +++++++++ mcpgateway/plugins/framework/types.py | 131 +++++++++++ mcpgateway/plugins/framework/utils.py | 97 ++++++++ plugins/config.yaml | 37 +++ plugins/regex/plugin-manifest.yaml | 9 + plugins/regex/search_replace.py | 65 ++++++ 11 files changed, 1153 insertions(+) create mode 100644 mcpgateway/plugins/framework/base.py create mode 100644 mcpgateway/plugins/framework/loader/config.py create mode 100644 mcpgateway/plugins/framework/loader/plugin.py create mode 100644 mcpgateway/plugins/framework/manager.py create mode 100644 mcpgateway/plugins/framework/models.py create mode 100644 mcpgateway/plugins/framework/registry.py create mode 100644 mcpgateway/plugins/framework/types.py create mode 100644 mcpgateway/plugins/framework/utils.py create mode 100644 plugins/config.yaml create mode 100644 plugins/regex/plugin-manifest.yaml create mode 100644 plugins/regex/search_replace.py diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py new file mode 100644 index 000000000..3aa2fb06c --- /dev/null +++ b/mcpgateway/plugins/framework/base.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- +"""Base plugin implementation. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +This module implements the base plugin object. +It supports pre and post hooks AI safety, security and business processing +for the following locations in the server: +server_pre_register / server_post_register - for virtual server verification +tool_pre_invoke / tool_post_invoke - for guardrails +prompt_pre_fetch / prompt_post_fetch - for prompt filtering +resource_pre_fetch / resource_post_fetch - for content filtering +auth_pre_check / auth_post_check - for custom auth logic +federation_pre_sync / federation_post_sync - for gateway federation +""" + +# Standard +import uuid + +# First-Party +from mcpgateway.plugins.framework.models import HookType, PluginCondition, PluginConfig, PluginMode +from mcpgateway.plugins.framework.types import ( + PluginContext, + PromptPosthookPayload, + PromptPosthookResult, + PromptPrehookPayload, + PromptPrehookResult, +) + + +class Plugin: + """Base plugin object for pre/post processing of inputs and outputs at various locations throughout the server.""" + + def __init__(self, config: PluginConfig) -> None: + """Initialize a plugin with a configuration and context. + + Args: + config: The plugin configuration + """ + self._config = config + + @property + def priority(self) -> int: + """Return the plugin's priority. + + Returns: + Plugin's priority. + """ + return self._config.priority + + @property + def mode(self) -> PluginMode: + """Return the plugin's mode. + + Returns: + Plugin's mode. + """ + return self._config.mode + + @property + def name(self) -> str: + """Return the plugin's name. + + Returns: + Plugin's name. + """ + return self._config.name + + @property + def hooks(self) -> list[HookType]: + """Return the plugin's currently configured hooks. + + Returns: + Plugin's configured hooks. + """ + return self._config.hooks + + @property + def tags(self) -> list[str]: + """Return the plugin's tags. + + Returns: + Plugin's tags. + """ + return self._config.tags + + @property + def conditions(self) -> list[PluginCondition] | None: + """Return the plugin's conditions for operation. + + Returns: + Plugin's conditions for executing. + """ + return self._config.conditions + + async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: + """Plugin hook run before a prompt is retrieved and rendered. + + Args: + payload: The prompt payload to be analyzed. + context: contextual information about the hook call. Including why it was called. + + Raises: + NotImplementedError: needs to be implemented by sub class. + """ + raise NotImplementedError(f"""'prompt_pre_fetch' not implemented for plugin {self._config.name} + of plugin type {type(self)} + """) + + async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: + """Plugin hook run after a prompt is rendered. + + Args: + payload: The prompt payload to be analyzed. + context: Contextual information about the hook call. + + Raises: + NotImplementedError: needs to be implemented by sub class. + """ + raise NotImplementedError(f"""'prompt_post_fetch' not implemented for plugin {self._config.name} + of plugin type {type(self)} + """) + + def shutdown(self) -> None: + """Plugin cleanup code.""" + + +class PluginRef: + """Plugin reference which contains a uuid.""" + + def __init__(self, plugin: Plugin): + """Initialize a plugin reference. + + Args: + plugin: The plugin to reference. + """ + self._plugin = plugin + self._uuid = uuid.uuid4() + + @property + def plugin(self) -> Plugin: + """Return the underlying plugin. + + Returns: + The underlying plugin. + """ + return self._plugin + + @property + def uuid(self) -> str: + """Return the plugin's UUID. + + Returns: + Plugin's UUID. + """ + return self._uuid.hex + + @property + def priority(self) -> int: + """Returns the plugin's priority. + + Returns: + Plugin's priority. + """ + return self._plugin.priority + + @property + def name(self) -> str: + """Return the plugin's name. + + Returns: + Plugin's name. + """ + return self._plugin.name + + @property + def hooks(self) -> list[HookType]: + """Returns the plugin's currently configured hooks. + + Returns: + Plugin's configured hooks. + """ + return self._plugin.hooks + + @property + def tags(self) -> list[str]: + """Return the plugin's tags. + + Returns: + Plugin's tags. + """ + return self._plugin.tags + + @property + def conditions(self) -> list[PluginCondition] | None: + """Return the plugin's conditions for operation. + + Returns: + Plugin's conditions for operation. + """ + return self._plugin.conditions + + @property + def mode(self) -> PluginMode: + """Return the plugin's mode. + + Returns: + Plugin's mode. + """ + return self.plugin.mode diff --git a/mcpgateway/plugins/framework/loader/config.py b/mcpgateway/plugins/framework/loader/config.py new file mode 100644 index 000000000..d9cc77f34 --- /dev/null +++ b/mcpgateway/plugins/framework/loader/config.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +"""Configuration loader implementation. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +This module loads configurations for plugins. +""" + +# Standard +import os + +# Third-Party +import jinja2 +import yaml + +# First-Party +from mcpgateway.plugins.framework.models import Config, PluginConfig, PluginManifest + + +class ConfigLoader: + """A configuration loader.""" + + @staticmethod + def load_config(config: str, use_jinja: bool = True) -> Config: + """Load the plugin configuration from a file path. + + Args: + config: the configuration path. + use_jinja: use jinja to replace env variables if true. + + Returns: + The plugin configuration object. + """ + with open(os.path.normpath(config), "r", encoding="utf-8") as file: + template = file.read() + if use_jinja: + jinja_env = jinja2.Environment(loader=jinja2.BaseLoader()) + rendered_template = jinja_env.from_string(template).render(env=os.environ) + else: + rendered_template = template + config_data = yaml.safe_load(rendered_template) + return Config(**config_data) + + @staticmethod + def dump_config(path: str, config: Config) -> None: + """Dump plugin configuration to a file. + + Args: + path: configuration file path + config: the plugin configuration path + """ + with open(os.path.normpath(path), "w", encoding="utf-8") as file: + yaml.safe_dump(config.model_dump(exclude_none=True), file) + + @staticmethod + def load_plugin_config(config: str) -> PluginConfig: + """Load a plugin configuration from a file path. + + This function autoescapes curly brackets in the 'instruction' + and 'examples' keys under the config attribute. + + Args: + config: the plugin configuration path + + Returns: + The plugin configuration object + """ + with open(os.path.normpath(config), "r", encoding="utf8") as file: + template = file.read() + jinja_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True) + rendered_template = jinja_env.from_string(template).render(env=os.environ) + config_data = yaml.safe_load(rendered_template) + return PluginConfig(**config_data) + + @staticmethod + def load_plugin_manifest(manifest: str) -> PluginManifest: + """Load a plugin manifest from a file path. + + Args: + manifest: the plugin manifest path + + Returns: + The plugin manifest object + """ + with open(os.path.normpath(manifest), "r", encoding="utf8") as file: + template = file.read() + jinja_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True) + rendered_template = jinja_env.from_string(template).render(env=os.environ) + config_data = yaml.safe_load(rendered_template) + return PluginManifest(**config_data) diff --git a/mcpgateway/plugins/framework/loader/plugin.py b/mcpgateway/plugins/framework/loader/plugin.py new file mode 100644 index 000000000..ae81c5aaf --- /dev/null +++ b/mcpgateway/plugins/framework/loader/plugin.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +"""Plugin loader implementation. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +This module implements the plugin loader. +""" + +# Standard +import logging +from typing import cast, Type + +# First-Party +from mcpgateway.plugins.framework.base import Plugin +from mcpgateway.plugins.framework.models import PluginConfig +from mcpgateway.plugins.framework.utils import import_module, parse_class_name + +logger = logging.getLogger(__name__) + + +class PluginLoader(object): + """A plugin loader object for loading and instantiating plugins.""" + + def __init__(self) -> None: + """Initialize the plugin loader.""" + self._plugin_types: dict[str, Type[Plugin]] = {} + + def __get_plugin_type(self, kind: str) -> Type[Plugin]: + try: + (mod_name, cls_name) = parse_class_name(kind) + module = import_module(mod_name) + class_ = getattr(module, cls_name) + return cast(Type[Plugin], class_) + except Exception: + logger.exception("Unable to instantiate class '%s'", kind) + raise + + def __register_plugin_type(self, kind: str) -> None: + if kind not in self._plugin_types: + plugin_type = self.__get_plugin_type(kind) + self._plugin_types[kind] = plugin_type + + async def load_and_instantiate_plugin(self, config: PluginConfig) -> Plugin | None: + """Load and instantiate a plugin, given a configuration. + + Args: + config: A plugin configuration. + + Returns: + A plugin instance. + """ + if config.kind not in self._plugin_types: + self.__register_plugin_type(config.kind) + plugin_type = self._plugin_types[config.kind] + if plugin_type: + return plugin_type(config) + return None diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py new file mode 100644 index 000000000..aed27f19c --- /dev/null +++ b/mcpgateway/plugins/framework/manager.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +"""Plugin manager. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Module that manages and calls plugins at hookpoints throughout the gateway. +""" + +# Standard +import logging +from typing import Optional + +# First-Party +from mcpgateway.plugins.framework.loader.config import ConfigLoader +from mcpgateway.plugins.framework.loader.plugin import PluginLoader +from mcpgateway.plugins.framework.models import Config, HookType, PluginMode +from mcpgateway.plugins.framework.registry import PluginInstanceRegistry +from mcpgateway.plugins.framework.types import ( + GlobalContext, + PluginContext, + PluginContextTable, + PromptPosthookPayload, + PromptPosthookResult, + PromptPrehookPayload, + PromptPrehookResult, +) +from mcpgateway.plugins.framework.utils import pre_prompt_matches + +logger = logging.getLogger(__name__) + + +class PluginManager: + """Plugin manager for managing the plugin lifecycle.""" + + def __init__(self, config: str): + """Initialize plugin manager. + + Args: + config: plugin configuration path. + """ + self._config: Config = ConfigLoader.load_config(config) + self._initialized: bool = False + self._loader: PluginLoader = PluginLoader() + self._registry: PluginInstanceRegistry = PluginInstanceRegistry() + + @property + def config(self) -> Config: + """Plugin manager configuration. + + Returns: + The plugin configuration. + """ + return self._config + + async def initialize(self) -> None: + """Initialize the plugin manager. + + Raises: + ValueError: if it cannot initialize the plugin. + """ + if self._initialized: + return + + for plugin_config in self._config.plugins: + if plugin_config.mode != PluginMode.DISABLED: + plugin = await self._loader.load_and_instantiate_plugin(plugin_config) + if plugin: + self._registry.register(plugin) + else: + raise ValueError(f"Unable to register and initialize plugin: {plugin_config.name}") + self._initialized = True + logger.info(f"Plugin manager initialized with {len(self._registry.get_all_plugins())} plugins") + + async def prompt_pre_fetch( + self, + payload: PromptPrehookPayload, + global_context: GlobalContext, + local_contexts: Optional[PluginContextTable] = None, + ) -> tuple[PromptPrehookResult | None, PluginContextTable | None]: + """Plugin hook run before a prompt is retrieved and rendered. + + Args: + payload: The prompt payload to be analyzed. + global_context: contextual information for all plugins. + local_contexts: context local to a single plugin. + + Returns: + The result of the plugin's analysis, including whether the prompt can proceed. + """ + plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + + if not plugins: + return (PromptPrehookResult(modified_payload=payload), None) + + res_local_contexts = {} + combined_metadata = {} + current_payload: PromptPrehookPayload | None = None + for pluginref in plugins: + if not pluginref.conditions or not pre_prompt_matches(payload, pluginref.conditions, global_context): + continue + local_context_key = global_context.request_id + pluginref.uuid + if local_contexts and local_context_key in local_contexts: + local_context = local_contexts[local_context_key] + else: + local_context = PluginContext(global_context) + res_local_contexts[local_context_key] = local_context + result = await pluginref.plugin.prompt_pre_fetch(payload, local_context) + + if result.metadata: + combined_metadata.update(result.metadata) + + if result.modified_payload is not None: + current_payload = result.modified_payload + + if not result.continue_processing: + # Check execution mode + if pluginref.plugin.mode == PluginMode.ENFORCE: + return (PromptPrehookResult(continue_processing=False, modified_payload=current_payload, error=result.error, metadata=combined_metadata), None) + elif pluginref.plugin.mode == PluginMode.PERMISSIVE: + logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.error}") + + return (PromptPrehookResult(continue_processing=True, modified_payload=current_payload, error=None, metadata=combined_metadata), res_local_contexts) + + async def prompt_post_fetch( + self, payload: PromptPosthookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None + ) -> tuple[PromptPosthookResult | None, PluginContextTable | None]: + """Plugin hook run after a prompt is rendered. + + Args: + payload: The prompt payload to be analyzed. + global_context: contextual information for all plugins. + local_contexts: context local to a single plugin. + + Returns: + The result of the plugin's analysis, including whether the prompt can proceed. + """ + return (None, None) diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py new file mode 100644 index 000000000..b4715a280 --- /dev/null +++ b/mcpgateway/plugins/framework/models.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- +"""Pydantic models for plugins. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +This module implements the pydantic models associated with +the base plugin layer including configurations, and contexts. +""" + +# Standard +from enum import Enum +from typing import Any, Optional + +# Third-Party +from pydantic import BaseModel + + +class HookType(str, Enum): + """MCP Forge Gateway hook points. + + Attributes: + prompt_pre_fetch: The prompt pre hook. + prompt_post_fetch: The prompt post hook. + """ + + PROMPT_PRE_FETCH = "prompt_pre_fetch" + PROMPT_POST_FETCH = "prompt_post_fetch" + + +class PluginMode(str, Enum): + """Plugin modes of operation. + + Attributes: + enforce: enforces the plugin result. + permissive: audits the result. + disabled: plugin disabled. + """ + + ENFORCE = "enforce" + PERMISSIVE = "permissive" + DISABLED = "disabled" + + +class ToolTemplate(BaseModel): + """Tool Template. + + Attributes: + tool_name (str): the name of the tool. + fields (Optional[list[str]]): the tool fields that are affected. + result (bool): analyze tool output if true. + """ + + tool_name: str + fields: Optional[list[str]] = None + result: bool = False + + +class PromptTemplate(BaseModel): + """Prompt Template. + + Attributes: + prompt_name (str): the name of the prompt. + fields (Optional[list[str]]): the prompt fields that are affected. + result (bool): analyze tool output if true. + """ + + prompt_name: str + fields: Optional[list[str]] = None + result: bool = False + + +class PluginCondition(BaseModel): + """Conditions for when plugin should execute. + + Attributes: + server_ids (Optional[set[str]]): set of server ids. + tenant_ids (Optional[set[str]]): set of tenant ids. + tools (Optional[set[str]]): set of tool names. + prompts (Optional[set[str]]): set of prompt names. + user_pattern (Optional[list[str]]): list of user patterns. + content_types (Optional[list[str]]): list of content types. + """ + + server_ids: Optional[set[str]] = None + tenant_ids: Optional[set[str]] = None + tools: Optional[set[str]] = None + prompts: Optional[set[str]] = None + user_patterns: Optional[list[str]] = None + content_types: Optional[list[str]] = None + + +class AppliedTo(BaseModel): + """What tools/prompts and fields the plugin will be applied to. + + Attributes: + tools (Optional[list[ToolTemplate]]): tools and fields to be applied. + prompts (Optional[list[PromptTemplate]]): prompts and fields to be applied. + """ + + tools: Optional[list[ToolTemplate]] = None + prompts: Optional[list[PromptTemplate]] = None + + +class PluginConfig(BaseModel): + """A plugin configuration. + + Attributes: + name (str): The unique name of the plugin. + description (str): A description of the plugin. + author (str): The author of the plugin. + kind (str): The kind or type of plugin. Usually a fully qualified object type. + namespace (str): The namespace where the plugin resides. + version (str): version of the plugin. + hooks (list[str]): a list of the hook points where the plugin will be called. + tags (list[str]): a list of tags for making the plugin searchable. + mode (bool): whether the plugin is active. + priority (int): indicates the order in which the plugin is run. Lower = higher priority. + conditions (Optional[list[PluginCondition]]): the conditions on which the plugin is run. + applied_to (Optional[list[AppliedTo]]): the tools, fields, that the plugin is applied to. + config (dict[str, Any]): the plugin specific configurations. + """ + + name: str + description: str + author: str + kind: str + namespace: Optional[str] = None + version: str + hooks: list[HookType] + tags: list[str] + mode: PluginMode = PluginMode.ENFORCE + priority: int = 100 # Lower = higher priority + conditions: Optional[list[PluginCondition]] = None # When to apply + applied_to: Optional[list[AppliedTo]] = None # Fields to apply to. + config: dict[str, Any] = {} + + +class PluginManifest(BaseModel): + """Plugin manifest. + + Attributes: + description (str): A description of the plugin. + author (str): The author of the plugin. + version (str): version of the plugin. + tags (list[str]): a list of tags for making the plugin searchable. + available_hooks (list[str]): a list of the hook points where the plugin is callable. + default_config (dict[str, Any]): the default configurations. + """ + + description: str + author: str + version: str + tags: list[str] + available_hooks: list[str] + default_config: dict[str, Any] + + +class PluginError(BaseModel): # (ErrorResponse): # Inherits from MCP error format + """A plugin error. + + Attributes: + plugin_name (str): The name of the plugin. + error_description (str): the error in text. + error_code (str): an error code. + details: (dict[str, Any]) + """ + + plugin_name: str + error_description: str + error_code: str + details: dict[str, Any] + + +class PluginSettings(BaseModel): + """Global plugin settings. + + Attributes: + parallel_execution_within_band (bool): execute plugins with same priority in parallel. + plugin_timeout (int): timeout value for plugins operations. + fail_on_plugin_error (bool): error when there is a plugin connectivity or ignore. + enable_plugin_api (bool): enable or disable plugins globally. + plugin_health_check_interval (int): health check interval check. + """ + + parallel_execution_within_band: bool = False + plugin_timeout: int = 30 + fail_on_plugin_error: bool = False + enable_plugin_api: bool = False + plugin_health_check_interval: int = 60 + + +class Config(BaseModel): + """Configurations for plugins. + + Attributes: + plugins: the list of plugins to enable. + plugin_dirs: The directories in which to look for plugins. + plugin_settings: global settings for plugins. + """ + + plugins: list[PluginConfig] = [] + plugin_dirs: list[str] = [] + plugin_settings: PluginSettings diff --git a/mcpgateway/plugins/framework/registry.py b/mcpgateway/plugins/framework/registry.py new file mode 100644 index 000000000..6f85627ea --- /dev/null +++ b/mcpgateway/plugins/framework/registry.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +"""Plugin instance registry. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Module that stores plugin instances and manages hook points. +""" + +# Standard +from collections import defaultdict +import logging +from typing import Optional + +# First-Party +from mcpgateway.plugins.framework.base import Plugin, PluginRef +from mcpgateway.plugins.framework.models import HookType + +logger = logging.getLogger(__name__) + + +class PluginInstanceRegistry: + """Registry for managing loaded plugins.""" + + def __init__(self) -> None: + """Initialize a plugin instance registry.""" + self._plugins: dict[str, PluginRef] = {} + self._hooks: dict[HookType, list[PluginRef]] = defaultdict(list) + self._priority_cache: dict[HookType, list[PluginRef]] = {} + + def register(self, plugin: Plugin) -> None: + """Register a plugin instance. + + Args: + plugin: plugin to be registered. + + Raises: + ValueError: if plugin is already registered. + """ + if plugin.name in self._plugins: + raise ValueError(f"Plugin {plugin.name} already registered") + + plugin_ref = PluginRef(plugin) + + self._plugins[plugin.name] = plugin_ref + + # Register hooks + for hook_type in plugin.hooks: + self._hooks[hook_type].append(plugin_ref) + # Invalidate priority cache for this hook + self._priority_cache.pop(hook_type, None) + + logger.info(f"Registered plugin: {plugin.name} with hooks: {[h.name for h in plugin.hooks]}") + + def unregister(self, plugin_name: str) -> None: + """Unregister a plugin given its name. + + Args: + plugin_name: The name of the plugin to unregister. + + Returns: + None + """ + if plugin_name not in self._plugins: + return + + plugin = self._plugins.pop(plugin_name) + # Remove from hooks + for hook_type in plugin.hooks: + self._hooks[hook_type] = [p for p in self._hooks[hook_type] if p.name != plugin_name] + self._priority_cache.pop(hook_type, None) + + logger.info(f"Unregistered plugin: {plugin_name}") + + def get_plugin(self, name: str) -> Optional[PluginRef]: + """Get a plugin by name. + + Args: + name: the name of the plugin to return. + + Returns: + A plugin. + """ + return self._plugins.get(name) + + def get_plugins_for_hook(self, hook_type: HookType) -> list[PluginRef]: + """Get all plugins for a specific hook, sorted by priority. + + Args: + hook_type: the hook type. + + Returns: + A list of plugin instances. + """ + if hook_type not in self._priority_cache: + plugins = sorted(self._hooks[hook_type], key=lambda p: p.priority) + self._priority_cache[hook_type] = plugins + return self._priority_cache[hook_type] + + def get_all_plugins(self) -> list[PluginRef]: + """Get all registered plugin instances. + + Returns: + A list of registered plugin instances. + """ + return list(self._plugins.values()) diff --git a/mcpgateway/plugins/framework/types.py b/mcpgateway/plugins/framework/types.py new file mode 100644 index 000000000..65909eea2 --- /dev/null +++ b/mcpgateway/plugins/framework/types.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +"""Pydantic models for plugins. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +This module implements the pydantic models associated with +the base plugin layer including configurations, and contexts. +""" + +# Standard +from typing import Any, Generic, Optional, TypeVar + +# First-Party +from mcpgateway.models import PromptResult +from mcpgateway.plugins.framework.models import PluginError + +T = TypeVar("T") + + +class PromptPrehookPayload: + """A prompt payload for a prompt prehook.""" + + def __init__(self, name: str, args: Optional[dict[str, str]]): + """Initialize a prompt prehook payload. + + Args: + name: The prompt name. + args: The prompt arguments for rendering. + """ + self.name = name + self.args = args + + +PromptPosthookPayload = PromptResult + + +class PluginResult(Generic[T]): + """A plugin result.""" + + def __init__(self, continue_processing: bool = True, modified_payload: Optional[T] = None, error: Optional[PluginError] = None, metadata: Optional[dict[str, Any]] = None): + """Initialize a plugin result object. + + Args: + continue_processing (bool): Whether to stop processing. + modified_payload (Optional[Any]): The modified payload if the plugin is a transformer. + error (Optional[PluginError]): error object. + metadata (Optional[dict[str, Any]]): additional metadata. + """ + self.continue_processing = continue_processing + self.modified_payload = modified_payload + self.error = error + self.metadata = metadata or {} + + +PromptPrehookResult = PluginResult[PromptPrehookPayload] +PromptPosthookResult = PluginResult[PromptPosthookPayload] + + +class GlobalContext: + """The global context, which shared across all plugins.""" + + def __init__( + self, + request_id: str, + user: Optional[str] = None, + tenant_id: Optional[str] = None, + server_id: Optional[str] = None, + ) -> None: + """Initialize a global context. + + Args: + request_id (str): ID of the HTTP request. + user (str): user ID associated with the request. + tenant_id (str): tenant ID. + server_id (str): server ID. + """ + self.request_id = request_id + self.user = user + self.tenant_id = tenant_id + self.server_id = server_id + + +class PluginContext(GlobalContext): + """The plugin's context, which lasts a request lifecycle. + + Attributes: + metadata: context metadata. + state: the inmemory state of the request. + """ + + def __init__(self, gcontext: Optional[GlobalContext] = None) -> None: + """Initialize a plugin context. + + Args: + gcontext: the global context object. + """ + if gcontext: + super().__init__(gcontext.request_id, gcontext.user, gcontext.tenant_id, gcontext.server_id) + self.state: dict[str, Any] = {} # In-memory state + self.metadata: dict[str, Any] = {} + + def get_state(self, key: str, default: Any = None) -> Any: + """Get value from shared state. + + Args: + key: The key to access the shared state. + default: A default value if one doesn't exist. + + Returns: + The state value. + """ + return self.state.get(key, default) + + def set_state(self, key: str, value: Any) -> None: + """Set value in shared state. + + Args: + key: the key to add to the state. + value: the value to add to the state. + """ + self.state[key] = value + + async def cleanup(self) -> None: + """Cleanup context resources.""" + self.state.clear() + self.metadata.clear() + + +PluginContextTable = dict[str, PluginContext] diff --git a/mcpgateway/plugins/framework/utils.py b/mcpgateway/plugins/framework/utils.py new file mode 100644 index 000000000..331c5fd9f --- /dev/null +++ b/mcpgateway/plugins/framework/utils.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +"""Utility module for plugins layer. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +This module implements the utility functions associated with +plugins. +""" + +# Standard +from functools import cache +import importlib +from types import ModuleType + +# First-Party +from mcpgateway.plugins.framework.models import PluginCondition +from mcpgateway.plugins.framework.types import GlobalContext, PromptPrehookPayload + + +@cache # noqa +def import_module(mod_name: str) -> ModuleType: + """Import a module. + + Args: + mod_name: fully qualified module name + + Returns: + A module. + """ + return importlib.import_module(mod_name) + + +def parse_class_name(name: str) -> tuple[str, str]: + """Parse a class name into its constituents. + + Args: + name: the qualified class name + + Returns: + A pair containing the qualified class prefix and the class name + """ + clslist = name.rsplit(".", 1) + if len(clslist) == 2: + return (clslist[0], clslist[1]) + return ("", name) + + +def matches(condition: PluginCondition, context: GlobalContext) -> bool: + """Check if conditions match the current context. + + Args: + condition: the conditions on the plugin that are required for execution. + context: the global context. + + Returns: + True if the plugin matches criteria. + """ + # Check server ID + if condition.server_ids and context.server_id not in condition.server_ids: + return False + + # Check tenant ID + if condition.tenant_ids and context.tenant_id not in condition.tenant_ids: + return False + + # Check user patterns (simple contains check, could be regex) + if condition.user_patterns and context.user: + if not any(pattern in context.user for pattern in condition.user_patterns): + return False + return True + + +def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: + """Check for a match on pre-prompt hooks. + + Args: + payload: the prompt prehook payload. + conditions: the conditions on the plugin that are required for execution. + context: the global context. + + Returns: + True if the plugin matches criteria. + """ + current_result = True + for index, condition in enumerate(conditions): + if not matches(condition, context): + current_result = False + + if condition.prompts and payload.name not in condition.prompts: + current_result = False + if current_result: + return True + elif index < len(conditions) - 1: + current_result = True + return current_result diff --git a/plugins/config.yaml b/plugins/config.yaml new file mode 100644 index 000000000..0a3a682d9 --- /dev/null +++ b/plugins/config.yaml @@ -0,0 +1,37 @@ +# plugins/config.yaml - Main plugin configuration file + +plugins: + # Self-contained Search Replace Plugin + - name: "ReplaceBadWordsPlugin" + kind: "plugins.regex.search_replace.SearchReplacePlugin" + description: "A plugin for finding and replacing words." + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] + tags: ["plugin", "transformer", "regex", "search-and-replace", "pre-post"] + mode: "enforce" # enforce | permissive | disabled + priority: 150 + conditions: + # Apply to specific tools/servers + - prompts: ["test_prompt"] + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - search: crap + replace: crud + + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: false + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/plugins/regex/plugin-manifest.yaml b/plugins/regex/plugin-manifest.yaml new file mode 100644 index 000000000..ad79264c3 --- /dev/null +++ b/plugins/regex/plugin-manifest.yaml @@ -0,0 +1,9 @@ +description: "Search replace plugin manifest." +author: "MCP Context Forge Team" +version: "0.1" +available_hooks: + - "prompt_pre_hook" + - "prompt_post_hook" + - "tool_pre_hook" + - "tool_post_hook" +default_configs: \ No newline at end of file diff --git a/plugins/regex/search_replace.py b/plugins/regex/search_replace.py new file mode 100644 index 000000000..6ee0c21c7 --- /dev/null +++ b/plugins/regex/search_replace.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +"""Simple example plugin for searching and replacing text. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +This module loads configurations for plugins. +""" +# Standard +import re + +# Third-Party +from pydantic import BaseModel + +# First-Party +from mcpgateway.plugins.framework.base import Plugin +from mcpgateway.plugins.framework.models import PluginConfig +from mcpgateway.plugins.framework.types import PluginContext, PromptPrehookPayload, PromptPrehookResult + + +class SearchReplace(BaseModel): + search: str + replace: str + +class SearchReplaceConfig(BaseModel): + words: list[SearchReplace] + + + +class SearchReplacePlugin(Plugin): + """Example search replace plugin""" + def __init__(self, config: PluginConfig): + super().__init__(config) + self._srconfig = SearchReplaceConfig.model_validate(self._config.config) + self.__patterns = [] + for word in self._srconfig.words: + self.__patterns.append((r'{}'.format(word.search), word.replace)) + + + + + async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: + """The plugin hook run before a prompt is retrieved and rendered. + Args: + payload: The prompt payload to be analyzed. + context: contextual information about the hook call. + + Returns: + The result of the plugin's analysis, including whether the prompt can proceed. + """ + if payload.args: + for pattern in self.__patterns: + for key in payload.args: + value = re.sub( + pattern[0], + pattern[1], + payload.args[key] + ) + payload.args[key] = value + return PromptPrehookResult(modified_payload=payload) + + + + From da2f63b5d7ec7653ae8792ff9684ea4d49651677 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Thu, 31 Jul 2025 21:19:50 -0600 Subject: [PATCH 02/20] feat(plugins): added prompt posthook functionality with executor, fixed some linting issues, updated example plugin with posthook. Signed-off-by: Teryl Taylor --- mcpgateway/plugins/framework/loader/config.py | 2 +- mcpgateway/plugins/framework/manager.py | 180 ++++++++++++++---- mcpgateway/plugins/framework/registry.py | 8 + mcpgateway/plugins/framework/types.py | 18 +- mcpgateway/plugins/framework/utils.py | 26 ++- plugins/config.yaml | 2 + plugins/regex/plugin-manifest.yaml | 2 +- plugins/regex/search_replace.py | 25 ++- 8 files changed, 212 insertions(+), 51 deletions(-) diff --git a/mcpgateway/plugins/framework/loader/config.py b/mcpgateway/plugins/framework/loader/config.py index d9cc77f34..d314c2e3c 100644 --- a/mcpgateway/plugins/framework/loader/config.py +++ b/mcpgateway/plugins/framework/loader/config.py @@ -36,7 +36,7 @@ def load_config(config: str, use_jinja: bool = True) -> Config: with open(os.path.normpath(config), "r", encoding="utf-8") as file: template = file.read() if use_jinja: - jinja_env = jinja2.Environment(loader=jinja2.BaseLoader()) + jinja_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True) rendered_template = jinja_env.from_string(template).render(env=os.environ) else: rendered_template = template diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index aed27f19c..ab2a61c1c 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -10,43 +10,139 @@ # Standard import logging -from typing import Optional +from typing import Any, Callable, Coroutine, Generic, Optional, TypeVar # First-Party +from mcpgateway.plugins.framework.base import PluginRef from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader -from mcpgateway.plugins.framework.models import Config, HookType, PluginMode +from mcpgateway.plugins.framework.models import Config, HookType, PluginCondition, PluginMode from mcpgateway.plugins.framework.registry import PluginInstanceRegistry from mcpgateway.plugins.framework.types import ( GlobalContext, PluginContext, PluginContextTable, + PluginResult, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, PromptPrehookResult, ) -from mcpgateway.plugins.framework.utils import pre_prompt_matches +from mcpgateway.plugins.framework.utils import post_prompt_matches, pre_prompt_matches logger = logging.getLogger(__name__) +T = TypeVar('T') + + +class PluginExecutor(Generic[T]): + """Executes a list of plugins.""" + async def execute( + self, + plugins: list[PluginRef], + payload: T, + global_context: GlobalContext, + plugin_run: Callable[[PluginRef, T, PluginContext], Coroutine[Any, Any, PluginResult[T]]], + compare: Callable[[T, list[PluginCondition], GlobalContext], bool], + local_contexts: Optional[PluginContextTable] = None, + ) -> tuple[PluginResult[T] | None, PluginContextTable | None]: + """Execute a plugins hook run before a prompt is retrieved and rendered. + + Args: + plugins: the list of plugins to execute. + payload: the payload to be analyzed. + global_context: contextual information for all plugins. + plugin_run: async function for executing plugin hook. + compare: function for comparing conditional information with context and payload + local_contexts: context local to a single plugin. + + Returns: + The result of the plugin's analysis, including whether the prompt can proceed. + """ + if not plugins: + return (PluginResult[T](modified_payload=None), None) + + res_local_contexts = {} + combined_metadata = {} + current_payload: T | None = None + for pluginref in plugins: + if not pluginref.conditions or not compare(payload, pluginref.conditions, global_context): + continue + local_context_key = global_context.request_id + pluginref.uuid + if local_contexts and local_context_key in local_contexts: + local_context = local_contexts[local_context_key] + else: + local_context = PluginContext(global_context) + res_local_contexts[local_context_key] = local_context + result = await plugin_run(pluginref, payload, local_context) + + if result.metadata: + combined_metadata.update(result.metadata) + + if result.modified_payload is not None: + current_payload = result.modified_payload + + if not result.continue_processing: + # Check execution mode + if pluginref.plugin.mode == PluginMode.ENFORCE: + return (PluginResult[T](continue_processing=False, modified_payload=current_payload, error=result.error, metadata=combined_metadata), None) + elif pluginref.plugin.mode == PluginMode.PERMISSIVE: + logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.error}") + + return (PluginResult[T](continue_processing=True, modified_payload=current_payload, error=None, metadata=combined_metadata), res_local_contexts) + + +async def pre_prompt_fetch(plugin: PluginRef, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: + """Call plugin's prompt pre-fetch hook. + + Args: + plugin: the plugin to execute. + payload: the prompt payload to be analyzed. + context: contextual information about the hook call. Including why it was called. + + Returns: + The result of the plugin execution. + """ + return await plugin.plugin.prompt_pre_fetch(payload, context) + + +async def post_prompt_fetch(plugin: PluginRef, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: + """Call plugin's prompt post-fetch hook. + + Args: + plugin: the plugin to execute. + payload: the prompt payload to be analyzed. + context: contextual information about the hook call. Including why it was called. + + Returns: + The result of the plugin execution. + """ + return await plugin.plugin.prompt_post_fetch(payload, context) + class PluginManager: """Plugin manager for managing the plugin lifecycle.""" - def __init__(self, config: str): + __shared_state: dict[Any, Any] = {} + _loader: PluginLoader = PluginLoader() + _initialized: bool = False + _registry: PluginInstanceRegistry = PluginInstanceRegistry() + _config: Config | None = None + _pre_prompt_executor: PluginExecutor[PromptPrehookPayload] = PluginExecutor[PromptPrehookPayload]() + _post_prompt_executor: PluginExecutor[PromptPosthookPayload] = PluginExecutor[PromptPosthookPayload]() + + def __init__(self, config: str = ""): """Initialize plugin manager. Args: config: plugin configuration path. """ - self._config: Config = ConfigLoader.load_config(config) - self._initialized: bool = False - self._loader: PluginLoader = PluginLoader() - self._registry: PluginInstanceRegistry = PluginInstanceRegistry() + self.__dict__ = self.__shared_state + if config: + self._config = ConfigLoader.load_config(config) @property - def config(self) -> Config: + def config(self) -> Config | None: """Plugin manager configuration. Returns: @@ -54,6 +150,24 @@ def config(self) -> Config: """ return self._config + @property + def plugin_count(self) -> int: + """Number of plugins loaded. + + Returns: + The number of plugins loaded. + """ + return self._registry.plugin_count + + @property + def initialized(self) -> bool: + """Plugin manager initialized. + + Returns: + True if the plugin manager is initialized. + """ + return self._initialized + async def initialize(self) -> None: """Initialize the plugin manager. @@ -62,8 +176,10 @@ async def initialize(self) -> None: """ if self._initialized: return + + plugins = self._config.plugins if self._config else [] - for plugin_config in self._config.plugins: + for plugin_config in plugins: if plugin_config.mode != PluginMode.DISABLED: plugin = await self._loader.load_and_instantiate_plugin(plugin_config) if plugin: @@ -73,6 +189,16 @@ async def initialize(self) -> None: self._initialized = True logger.info(f"Plugin manager initialized with {len(self._registry.get_all_plugins())} plugins") + async def shutdown(self) -> None: + """Shutdown all plugins.""" + for plugin_ref in self._registry.get_all_plugins(): + try: + await plugin_ref.plugin.shutdown() + except Exception as e: + logger.error(f"Error shutting down plugin {plugin_ref.plugin.name}: {e}") + + self._initialized = False + async def prompt_pre_fetch( self, payload: PromptPrehookPayload, @@ -90,38 +216,9 @@ async def prompt_pre_fetch( The result of the plugin's analysis, including whether the prompt can proceed. """ plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + return await self._pre_prompt_executor.execute(plugins, payload, global_context, pre_prompt_fetch, pre_prompt_matches, local_contexts) - if not plugins: - return (PromptPrehookResult(modified_payload=payload), None) - - res_local_contexts = {} - combined_metadata = {} - current_payload: PromptPrehookPayload | None = None - for pluginref in plugins: - if not pluginref.conditions or not pre_prompt_matches(payload, pluginref.conditions, global_context): - continue - local_context_key = global_context.request_id + pluginref.uuid - if local_contexts and local_context_key in local_contexts: - local_context = local_contexts[local_context_key] - else: - local_context = PluginContext(global_context) - res_local_contexts[local_context_key] = local_context - result = await pluginref.plugin.prompt_pre_fetch(payload, local_context) - - if result.metadata: - combined_metadata.update(result.metadata) - - if result.modified_payload is not None: - current_payload = result.modified_payload - - if not result.continue_processing: - # Check execution mode - if pluginref.plugin.mode == PluginMode.ENFORCE: - return (PromptPrehookResult(continue_processing=False, modified_payload=current_payload, error=result.error, metadata=combined_metadata), None) - elif pluginref.plugin.mode == PluginMode.PERMISSIVE: - logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.error}") - return (PromptPrehookResult(continue_processing=True, modified_payload=current_payload, error=None, metadata=combined_metadata), res_local_contexts) async def prompt_post_fetch( self, payload: PromptPosthookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None @@ -136,4 +233,5 @@ async def prompt_post_fetch( Returns: The result of the plugin's analysis, including whether the prompt can proceed. """ - return (None, None) + plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_POST_FETCH) + return await self._post_prompt_executor.execute(plugins, payload, global_context, post_prompt_fetch, post_prompt_matches, local_contexts) diff --git a/mcpgateway/plugins/framework/registry.py b/mcpgateway/plugins/framework/registry.py index 6f85627ea..c47da1c79 100644 --- a/mcpgateway/plugins/framework/registry.py +++ b/mcpgateway/plugins/framework/registry.py @@ -105,3 +105,11 @@ def get_all_plugins(self) -> list[PluginRef]: A list of registered plugin instances. """ return list(self._plugins.values()) + + def plugin_count(self) -> int: + """Return the number of plugins registered. + + Returns: + The number of plugins registered. + """ + return len(self._plugins) diff --git a/mcpgateway/plugins/framework/types.py b/mcpgateway/plugins/framework/types.py index 65909eea2..2e21e73a1 100644 --- a/mcpgateway/plugins/framework/types.py +++ b/mcpgateway/plugins/framework/types.py @@ -30,10 +30,21 @@ def __init__(self, name: str, args: Optional[dict[str, str]]): args: The prompt arguments for rendering. """ self.name = name - self.args = args + self.args = args or {} -PromptPosthookPayload = PromptResult +class PromptPosthookPayload: + """A prompt payload for a prompt posthook.""" + + def __init__(self, name: str, result: PromptResult): + """Initialize a prompt posthook payload. + + Args: + name: The prompt name. + result: The prompt Prompt Result. + """ + self.name = name + self.result = result class PluginResult(Generic[T]): @@ -66,7 +77,7 @@ def __init__( request_id: str, user: Optional[str] = None, tenant_id: Optional[str] = None, - server_id: Optional[str] = None, + server_id: Optional[str] = None ) -> None: """Initialize a global context. @@ -81,7 +92,6 @@ def __init__( self.tenant_id = tenant_id self.server_id = server_id - class PluginContext(GlobalContext): """The plugin's context, which lasts a request lifecycle. diff --git a/mcpgateway/plugins/framework/utils.py b/mcpgateway/plugins/framework/utils.py index 331c5fd9f..caac50bff 100644 --- a/mcpgateway/plugins/framework/utils.py +++ b/mcpgateway/plugins/framework/utils.py @@ -16,7 +16,7 @@ # First-Party from mcpgateway.plugins.framework.models import PluginCondition -from mcpgateway.plugins.framework.types import GlobalContext, PromptPrehookPayload +from mcpgateway.plugins.framework.types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload @cache # noqa @@ -95,3 +95,27 @@ def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCon elif index < len(conditions) - 1: current_result = True return current_result + +def post_prompt_matches(payload: PromptPosthookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: + """Check for a match on pre-prompt hooks. + + Args: + payload: the prompt posthook payload. + conditions: the conditions on the plugin that are required for execution. + context: the global context. + + Returns: + True if the plugin matches criteria. + """ + current_result = True + for index, condition in enumerate(conditions): + if not matches(condition, context): + current_result = False + + if condition.prompts and payload.name not in condition.prompts: + current_result = False + if current_result: + return True + elif index < len(conditions) - 1: + current_result = True + return current_result diff --git a/plugins/config.yaml b/plugins/config.yaml index 0a3a682d9..bbf704271 100644 --- a/plugins/config.yaml +++ b/plugins/config.yaml @@ -20,6 +20,8 @@ plugins: words: - search: crap replace: crud + - search: crud + replace: yikes # Plugin directories to scan diff --git a/plugins/regex/plugin-manifest.yaml b/plugins/regex/plugin-manifest.yaml index ad79264c3..8fc8f1505 100644 --- a/plugins/regex/plugin-manifest.yaml +++ b/plugins/regex/plugin-manifest.yaml @@ -6,4 +6,4 @@ available_hooks: - "prompt_post_hook" - "tool_pre_hook" - "tool_post_hook" -default_configs: \ No newline at end of file +default_configs: diff --git a/plugins/regex/search_replace.py b/plugins/regex/search_replace.py index 6ee0c21c7..3a4bfcad7 100644 --- a/plugins/regex/search_replace.py +++ b/plugins/regex/search_replace.py @@ -16,7 +16,7 @@ # First-Party from mcpgateway.plugins.framework.base import Plugin from mcpgateway.plugins.framework.models import PluginConfig -from mcpgateway.plugins.framework.types import PluginContext, PromptPrehookPayload, PromptPrehookResult +from mcpgateway.plugins.framework.types import PluginContext, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, PromptPrehookResult class SearchReplace(BaseModel): @@ -37,15 +37,16 @@ def __init__(self, config: PluginConfig): for word in self._srconfig.words: self.__patterns.append((r'{}'.format(word.search), word.replace)) - + async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: """The plugin hook run before a prompt is retrieved and rendered. + Args: payload: The prompt payload to be analyzed. context: contextual information about the hook call. - + Returns: The result of the plugin's analysis, including whether the prompt can proceed. """ @@ -60,6 +61,24 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginC payload.args[key] = value return PromptPrehookResult(modified_payload=payload) + async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: + """Plugin hook run after a prompt is rendered. + Args: + payload: The prompt payload to be analyzed. + context: Contextual information about the hook call. + Returns: + The result of the plugin's analysis, including whether the prompt can proceed. + """ + if payload.result.messages: + for index, message in enumerate(payload.result.messages): + for pattern in self.__patterns: + value = re.sub( + pattern[0], + pattern[1], + message.content.text + ) + payload.result.messages[index] = value + return PromptPosthookResult(modified_payload=payload) From 14c9fae07d54311d708562c87e0805f99be5e613 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Fri, 1 Aug 2025 18:26:16 -0600 Subject: [PATCH 03/20] feat(plugins): integrated plugins into prompt service, fixed linting and type issues. Signed-off-by: Teryl Taylor --- .env.example | 5 ++ MANIFEST.in | 2 + mcpgateway/config.py | 8 ++- mcpgateway/main.py | 13 ++++ mcpgateway/plugins/framework/base.py | 2 +- mcpgateway/plugins/framework/loader/plugin.py | 18 ++++- mcpgateway/plugins/framework/manager.py | 21 +++--- mcpgateway/plugins/framework/models.py | 12 ++-- mcpgateway/plugins/framework/registry.py | 1 + mcpgateway/plugins/framework/types.py | 17 ++--- mcpgateway/plugins/framework/utils.py | 1 + mcpgateway/services/prompt_service.py | 72 ++++++++++++++++++- plugins/regex/search_replace.py | 2 +- 13 files changed, 137 insertions(+), 37 deletions(-) diff --git a/.env.example b/.env.example index 2cb0b19fa..ac938a702 100644 --- a/.env.example +++ b/.env.example @@ -290,3 +290,8 @@ DEBUG=false # Gateway tool name separator GATEWAY_TOOL_NAME_SEPARATOR=- VALID_SLUG_SEPARATOR_REGEXP= r"^(-{1,2}|[_.])$" + +##################################### +# Plugins Settings +##################################### +PLUGINS_ENABLED=false diff --git a/MANIFEST.in b/MANIFEST.in index 71b09ceb0..d810b20a2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -56,6 +56,8 @@ recursive-include alembic *.md recursive-include alembic *.py # recursive-include deployment * # recursive-include mcp-servers * +recursive-include plugins *.py +recursive-include plugins *.yaml # 5️⃣ (Optional) include MKDocs-based docs in the sdist # graft docs diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 1c6938b19..28d4fd8df 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -311,6 +311,10 @@ def _parse_federation_peers(cls, v): use_stateful_sessions: bool = False # Set to False to use stateless sessions without event store json_response_enabled: bool = True # Enable JSON responses instead of SSE streams + # Core plugin settings + plugins_enabled: bool = Field(default=False, description="Enable the plugin framework") + plugin_config_file: str = Field(default="plugins/config.yaml", description="Path to main plugin configuration file") + # Development dev_mode: bool = False reload: bool = False @@ -497,9 +501,7 @@ def validate_database(self) -> None: db_dir.mkdir(parents=True) # Validation patterns for safe display (configurable) - validation_dangerous_html_pattern: str = ( - r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" - ) + validation_dangerous_html_pattern: str = r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" validation_dangerous_js_pattern: str = r"(?i)(?:^|\s|[\"'`<>=])(javascript:|vbscript:|data:\s*[^,]*[;\s]*(javascript|vbscript)|\bon[a-z]+\s*=|<\s*script\b)" diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 8412ced1a..a2bf23cce 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -75,6 +75,7 @@ ResourceContent, Root, ) +from mcpgateway.plugins.framework.manager import PluginManager from mcpgateway.schemas import ( GatewayCreate, GatewayRead, @@ -160,6 +161,8 @@ else: loop.create_task(bootstrap_db()) +# Initialize plugin manager as a singleton. +plugin_manager: PluginManager | None = PluginManager(settings.plugin_config_file) if settings.plugins_enabled else None # Initialize services tool_service = ToolService() @@ -214,6 +217,9 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: """ logger.info("Starting MCP Gateway services") try: + if plugin_manager: + await plugin_manager.initialize() + logger.info(f"Plugin manager initialized with {plugin_manager.plugin_count} plugins") await tool_service.initialize() await resource_service.initialize() await prompt_service.initialize() @@ -232,6 +238,13 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: logger.error(f"Error during startup: {str(e)}") raise finally: + # Shutdown plugin manager + if plugin_manager: + try: + await plugin_manager.shutdown() + logger.info("Plugin manager shutdown complete") + except Exception as e: + logger.error(f"Error shutting down plugin manager: {str(e)}") logger.info("Shutting down MCP Gateway services") # await stop_streamablehttp() for service in [resource_cache, sampling_handler, logging_service, completion_service, root_service, gateway_service, prompt_service, resource_service, tool_service, streamable_http_session]: diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index 3aa2fb06c..d715ae926 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -123,7 +123,7 @@ async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: Plugi of plugin type {type(self)} """) - def shutdown(self) -> None: + async def shutdown(self) -> None: """Plugin cleanup code.""" diff --git a/mcpgateway/plugins/framework/loader/plugin.py b/mcpgateway/plugins/framework/loader/plugin.py index ae81c5aaf..2be7e075a 100644 --- a/mcpgateway/plugins/framework/loader/plugin.py +++ b/mcpgateway/plugins/framework/loader/plugin.py @@ -28,16 +28,32 @@ def __init__(self) -> None: self._plugin_types: dict[str, Type[Plugin]] = {} def __get_plugin_type(self, kind: str) -> Type[Plugin]: + """Import a plugin type from a python module. + + Args: + kind: The fully-qualified type of the plugin to be registered. + + Raises: + Exception: if unable to import a module. + + Returns: + A plugin type. + """ try: (mod_name, cls_name) = parse_class_name(kind) module = import_module(mod_name) class_ = getattr(module, cls_name) return cast(Type[Plugin], class_) except Exception: - logger.exception("Unable to instantiate class '%s'", kind) + logger.exception("Unable to import plugin type '%s'", kind) raise def __register_plugin_type(self, kind: str) -> None: + """Register a plugin type. + + Args: + kind: The fully-qualified type of the plugin to be registered. + """ if kind not in self._plugin_types: plugin_type = self.__get_plugin_type(kind) self._plugin_types[kind] = plugin_type diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index ab2a61c1c..0a6d978a2 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -32,11 +32,12 @@ logger = logging.getLogger(__name__) -T = TypeVar('T') +T = TypeVar("T") class PluginExecutor(Generic[T]): """Executes a list of plugins.""" + async def execute( self, plugins: list[PluginRef], @@ -45,7 +46,7 @@ async def execute( plugin_run: Callable[[PluginRef, T, PluginContext], Coroutine[Any, Any, PluginResult[T]]], compare: Callable[[T, list[PluginCondition], GlobalContext], bool], local_contexts: Optional[PluginContextTable] = None, - ) -> tuple[PluginResult[T] | None, PluginContextTable | None]: + ) -> tuple[PluginResult[T], PluginContextTable | None]: """Execute a plugins hook run before a prompt is retrieved and rendered. Args: @@ -53,7 +54,7 @@ async def execute( payload: the payload to be analyzed. global_context: contextual information for all plugins. plugin_run: async function for executing plugin hook. - compare: function for comparing conditional information with context and payload + compare: function for comparing conditional information with context and payload. local_contexts: context local to a single plugin. Returns: @@ -85,11 +86,11 @@ async def execute( if not result.continue_processing: # Check execution mode if pluginref.plugin.mode == PluginMode.ENFORCE: - return (PluginResult[T](continue_processing=False, modified_payload=current_payload, error=result.error, metadata=combined_metadata), None) + return (PluginResult[T](continue_processing=False, modified_payload=current_payload, violation=result.violation, metadata=combined_metadata), None) elif pluginref.plugin.mode == PluginMode.PERMISSIVE: - logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.error}") + logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.violation.description if result.violation else ''}") - return (PluginResult[T](continue_processing=True, modified_payload=current_payload, error=None, metadata=combined_metadata), res_local_contexts) + return (PluginResult[T](continue_processing=True, modified_payload=current_payload, violation=None, metadata=combined_metadata), res_local_contexts) async def pre_prompt_fetch(plugin: PluginRef, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: @@ -176,7 +177,7 @@ async def initialize(self) -> None: """ if self._initialized: return - + plugins = self._config.plugins if self._config else [] for plugin_config in plugins: @@ -204,7 +205,7 @@ async def prompt_pre_fetch( payload: PromptPrehookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, - ) -> tuple[PromptPrehookResult | None, PluginContextTable | None]: + ) -> tuple[PromptPrehookResult, PluginContextTable | None]: """Plugin hook run before a prompt is retrieved and rendered. Args: @@ -218,11 +219,9 @@ async def prompt_pre_fetch( plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) return await self._pre_prompt_executor.execute(plugins, payload, global_context, pre_prompt_fetch, pre_prompt_matches, local_contexts) - - async def prompt_post_fetch( self, payload: PromptPosthookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None - ) -> tuple[PromptPosthookResult | None, PluginContextTable | None]: + ) -> tuple[PromptPosthookResult, PluginContextTable | None]: """Plugin hook run after a prompt is rendered. Args: diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index b4715a280..d174b80d7 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -157,19 +157,19 @@ class PluginManifest(BaseModel): default_config: dict[str, Any] -class PluginError(BaseModel): # (ErrorResponse): # Inherits from MCP error format - """A plugin error. +class PluginViolation(BaseModel): + """A plugin filter violation. Attributes: plugin_name (str): The name of the plugin. - error_description (str): the error in text. - error_code (str): an error code. + description (str): the violation in text. + violation_code (str): a violation code. details: (dict[str, Any]) """ plugin_name: str - error_description: str - error_code: str + description: str + violation_code: str details: dict[str, Any] diff --git a/mcpgateway/plugins/framework/registry.py b/mcpgateway/plugins/framework/registry.py index c47da1c79..7c2a5dce0 100644 --- a/mcpgateway/plugins/framework/registry.py +++ b/mcpgateway/plugins/framework/registry.py @@ -106,6 +106,7 @@ def get_all_plugins(self) -> list[PluginRef]: """ return list(self._plugins.values()) + @property def plugin_count(self) -> int: """Return the number of plugins registered. diff --git a/mcpgateway/plugins/framework/types.py b/mcpgateway/plugins/framework/types.py index 2e21e73a1..9ee70ebf8 100644 --- a/mcpgateway/plugins/framework/types.py +++ b/mcpgateway/plugins/framework/types.py @@ -14,7 +14,7 @@ # First-Party from mcpgateway.models import PromptResult -from mcpgateway.plugins.framework.models import PluginError +from mcpgateway.plugins.framework.models import PluginViolation T = TypeVar("T") @@ -50,18 +50,18 @@ def __init__(self, name: str, result: PromptResult): class PluginResult(Generic[T]): """A plugin result.""" - def __init__(self, continue_processing: bool = True, modified_payload: Optional[T] = None, error: Optional[PluginError] = None, metadata: Optional[dict[str, Any]] = None): + def __init__(self, continue_processing: bool = True, modified_payload: Optional[T] = None, violation: Optional[PluginViolation] = None, metadata: Optional[dict[str, Any]] = None): """Initialize a plugin result object. Args: continue_processing (bool): Whether to stop processing. modified_payload (Optional[Any]): The modified payload if the plugin is a transformer. - error (Optional[PluginError]): error object. + violation (Optional[PluginViolation]): violation object. metadata (Optional[dict[str, Any]]): additional metadata. """ self.continue_processing = continue_processing self.modified_payload = modified_payload - self.error = error + self.violation = violation self.metadata = metadata or {} @@ -72,13 +72,7 @@ def __init__(self, continue_processing: bool = True, modified_payload: Optional[ class GlobalContext: """The global context, which shared across all plugins.""" - def __init__( - self, - request_id: str, - user: Optional[str] = None, - tenant_id: Optional[str] = None, - server_id: Optional[str] = None - ) -> None: + def __init__(self, request_id: str, user: Optional[str] = None, tenant_id: Optional[str] = None, server_id: Optional[str] = None) -> None: """Initialize a global context. Args: @@ -92,6 +86,7 @@ def __init__( self.tenant_id = tenant_id self.server_id = server_id + class PluginContext(GlobalContext): """The plugin's context, which lasts a request lifecycle. diff --git a/mcpgateway/plugins/framework/utils.py b/mcpgateway/plugins/framework/utils.py index caac50bff..597d1a7d0 100644 --- a/mcpgateway/plugins/framework/utils.py +++ b/mcpgateway/plugins/framework/utils.py @@ -96,6 +96,7 @@ def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCon current_result = True return current_result + def post_prompt_matches(payload: PromptPosthookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: """Check for a match on pre-prompt hooks. diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 7cd4ff920..bb8af58db 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -20,6 +20,7 @@ import logging from string import Formatter from typing import Any, AsyncGenerator, Dict, List, Optional, Set +import uuid # Third-Party from jinja2 import Environment, meta, select_autoescape @@ -28,9 +29,12 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.config import settings from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric, server_prompt_association from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.plugins.framework.manager import PluginManager +from mcpgateway.plugins.framework.types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate logger = logging.getLogger(__name__) @@ -113,6 +117,7 @@ def __init__(self) -> None: """ self._event_subscribers: List[asyncio.Queue] = [] self._jinja_env = Environment(autoescape=select_autoescape(["html", "xml"]), trim_blocks=True, lstrip_blocks=True) + self._plugin_manager: PluginManager | None = PluginManager() if settings.plugins_enabled else None async def initialize(self) -> None: """Initialize the service.""" @@ -349,13 +354,26 @@ async def list_server_prompts(self, db: Session, server_id: str, include_inactiv prompts = db.execute(query).scalars().all() return [PromptRead.model_validate(self._convert_db_prompt(p)) for p in prompts] - async def get_prompt(self, db: Session, name: str, arguments: Optional[Dict[str, str]] = None) -> PromptResult: + async def get_prompt( + self, + db: Session, + name: str, + arguments: Optional[Dict[str, str]] = None, + user: Optional[str] = None, + tenant_id: Optional[str] = None, + server_id: Optional[str] = None, + request_id: Optional[str] = None, + ) -> PromptResult: """Get a prompt template and optionally render it. Args: db: Database session name: Name of prompt to get arguments: Optional arguments for rendering + user: Optional user identifier for plugin context + tenant_id: Optional tenant identifier for plugin context + server_id: Optional server identifier for plugin context + request_id: Optional request ID, generated if not provided Returns: Prompt result with rendered messages @@ -376,6 +394,34 @@ async def get_prompt(self, db: Session, name: str, arguments: Optional[Dict[str, ... except Exception: ... pass """ + + if self._plugin_manager: + if not request_id: + request_id = uuid.uuid4().hex + global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id, tenant_id=tenant_id) + try: + pre_result, context_table = await self._plugin_manager.prompt_pre_fetch(payload=PromptPrehookPayload(name, arguments), global_context=global_context, local_contexts=None) + + if not pre_result.continue_processing: + # Plugin blocked the request + if pre_result.violation: + violation_desc = pre_result.violation.description + plugin_name = pre_result.violation.plugin_name + violation_code = pre_result.violation.violation_code + raise PromptError(f"Pre prompting fetch blocked by plugin {plugin_name}: {violation_code} {violation_desc}") + raise PromptError("Pre prompting fetch blocked by plugin") + + # Use modified payload if provided + if pre_result.modified_payload: + payload = pre_result.modified_payload + name = payload.name + arguments = payload.args + except Exception as e: + logger.error(f"Error in pre-prompt fetch plugin hook: {e}") + # Only fail if configured to do so + if self._plugin_manager.config and self._plugin_manager.config.plugin_settings.fail_on_plugin_error: + raise + # Find prompt prompt = db.execute(select(DbPrompt).where(DbPrompt.name == name).where(DbPrompt.is_active)).scalar_one_or_none() @@ -387,7 +433,7 @@ async def get_prompt(self, db: Session, name: str, arguments: Optional[Dict[str, raise PromptNotFoundError(f"Prompt not found: {name}") if not arguments: - return PromptResult( + result = PromptResult( messages=[ Message( role=Role.USER, @@ -401,10 +447,30 @@ async def get_prompt(self, db: Session, name: str, arguments: Optional[Dict[str, prompt.validate_arguments(arguments) rendered = self._render_template(prompt.template, arguments) messages = self._parse_messages(rendered) - return PromptResult(messages=messages, description=prompt.description) + result = PromptResult(messages=messages, description=prompt.description) except Exception as e: raise PromptError(f"Failed to process prompt: {str(e)}") + if self._plugin_manager: + try: + post_result, _ = await self._plugin_manager.prompt_post_fetch(payload=PromptPosthookPayload(name=name, result=result), global_context=global_context, local_contexts=context_table) + if not post_result.continue_processing: + # Plugin blocked the request + if post_result.violation: + violation_desc = post_result.violation.description + plugin_name = post_result.violation.plugin_name + violation_code = post_result.violation.violation_code + raise PromptError(f"Post prompting fetch blocked by plugin {plugin_name}: {violation_code} {violation_desc}") + raise PromptError("Post prompting fetch blocked by plugin") + # Use modified payload if provided + return post_result.modified_payload.result if post_result.modified_payload else result + except Exception as e: + logger.error(f"Error in post-prompt fetch plugin hook: {e}") + # Only fail if configured to do so + if self._plugin_manager.config and self._plugin_manager.config.plugin_settings.fail_on_plugin_error: + raise + return result + async def update_prompt(self, db: Session, name: str, prompt_update: PromptUpdate) -> PromptRead: """ Update a prompt template. diff --git a/plugins/regex/search_replace.py b/plugins/regex/search_replace.py index 3a4bfcad7..00e2db559 100644 --- a/plugins/regex/search_replace.py +++ b/plugins/regex/search_replace.py @@ -80,5 +80,5 @@ async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: Plugi pattern[1], message.content.text ) - payload.result.messages[index] = value + payload.result.messages[index].content.text = value return PromptPosthookResult(modified_payload=payload) From 4e7347c0f38cf0e1a805cc415550424ad9ab32f1 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Mon, 4 Aug 2025 09:50:09 -0600 Subject: [PATCH 04/20] fix(plugins): renamed types.py to plugin_types.py due to conflict in pytest Signed-off-by: Teryl Taylor --- mcpgateway/plugins/framework/base.py | 2 +- mcpgateway/plugins/framework/manager.py | 4 ++-- mcpgateway/plugins/framework/{types.py => plugin_types.py} | 0 mcpgateway/plugins/framework/utils.py | 2 +- plugins/regex/search_replace.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) rename mcpgateway/plugins/framework/{types.py => plugin_types.py} (100%) diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index d715ae926..10e1e8161 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -21,7 +21,7 @@ # First-Party from mcpgateway.plugins.framework.models import HookType, PluginCondition, PluginConfig, PluginMode -from mcpgateway.plugins.framework.types import ( +from mcpgateway.plugins.framework.plugin_types import ( PluginContext, PromptPosthookPayload, PromptPosthookResult, diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index 0a6d978a2..a74eb7283 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -17,8 +17,7 @@ from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework.models import Config, HookType, PluginCondition, PluginMode -from mcpgateway.plugins.framework.registry import PluginInstanceRegistry -from mcpgateway.plugins.framework.types import ( +from mcpgateway.plugins.framework.plugin_types import ( GlobalContext, PluginContext, PluginContextTable, @@ -28,6 +27,7 @@ PromptPrehookPayload, PromptPrehookResult, ) +from mcpgateway.plugins.framework.registry import PluginInstanceRegistry from mcpgateway.plugins.framework.utils import post_prompt_matches, pre_prompt_matches logger = logging.getLogger(__name__) diff --git a/mcpgateway/plugins/framework/types.py b/mcpgateway/plugins/framework/plugin_types.py similarity index 100% rename from mcpgateway/plugins/framework/types.py rename to mcpgateway/plugins/framework/plugin_types.py diff --git a/mcpgateway/plugins/framework/utils.py b/mcpgateway/plugins/framework/utils.py index 597d1a7d0..7ac7da8e7 100644 --- a/mcpgateway/plugins/framework/utils.py +++ b/mcpgateway/plugins/framework/utils.py @@ -16,7 +16,7 @@ # First-Party from mcpgateway.plugins.framework.models import PluginCondition -from mcpgateway.plugins.framework.types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework.plugin_types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload @cache # noqa diff --git a/plugins/regex/search_replace.py b/plugins/regex/search_replace.py index 00e2db559..f04974396 100644 --- a/plugins/regex/search_replace.py +++ b/plugins/regex/search_replace.py @@ -16,7 +16,7 @@ # First-Party from mcpgateway.plugins.framework.base import Plugin from mcpgateway.plugins.framework.models import PluginConfig -from mcpgateway.plugins.framework.types import PluginContext, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, PromptPrehookResult +from mcpgateway.plugins.framework.plugin_types import PluginContext, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, PromptPrehookResult class SearchReplace(BaseModel): From d71c0a3b29f05c8b9ca4d01eba75201d86faa177 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Mon, 4 Aug 2025 10:15:21 -0600 Subject: [PATCH 05/20] fix(plugins): fixed renamed plugin_types module in prompt_service.py Signed-off-by: Teryl Taylor --- mcpgateway/services/prompt_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index bb8af58db..a991600e2 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -34,7 +34,7 @@ from mcpgateway.db import PromptMetric, server_prompt_association from mcpgateway.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework.manager import PluginManager -from mcpgateway.plugins.framework.types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework.plugin_types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate logger = logging.getLogger(__name__) From 287d1d987af1f3f89c991e577328ec9c4eaed3df Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Mon, 4 Aug 2025 22:43:57 -0400 Subject: [PATCH 06/20] feat: add example filter plugin Signed-off-by: Frederico Araujo --- mcpgateway/config.py | 4 +- mcpgateway/main.py | 4 +- mcpgateway/plugins/framework/base.py | 12 ++++-- mcpgateway/plugins/framework/manager.py | 3 ++ mcpgateway/plugins/framework/models.py | 18 ++++---- mcpgateway/services/prompt_service.py | 18 +++++--- plugins/config.yaml | 20 ++++++++- plugins/filter/deny.py | 57 +++++++++++++++++++++++++ plugins/filter/plugin-manifest.yaml | 6 +++ plugins/regex/search_replace.py | 2 +- 10 files changed, 121 insertions(+), 23 deletions(-) create mode 100644 plugins/filter/deny.py create mode 100644 plugins/filter/plugin-manifest.yaml diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 28d4fd8df..db80a501e 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -501,7 +501,9 @@ def validate_database(self) -> None: db_dir.mkdir(parents=True) # Validation patterns for safe display (configurable) - validation_dangerous_html_pattern: str = r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" + validation_dangerous_html_pattern: str = ( + r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" + ) validation_dangerous_js_pattern: str = r"(?i)(?:^|\s|[\"'`<>=])(javascript:|vbscript:|data:\s*[^,]*[;\s]*(javascript|vbscript)|\bon[a-z]+\s*=|<\s*script\b)" diff --git a/mcpgateway/main.py b/mcpgateway/main.py index a2bf23cce..f82921180 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -1716,8 +1716,8 @@ async def get_prompt( PromptExecuteArgs(args=args) return await prompt_service.get_prompt(db, name, args) except Exception as ex: - logger.error(f"Error retrieving prompt {name}: {ex}") - if isinstance(ex, ValueError): + logger.error(f"Could not retrieve prompt {name}: {ex}") + if isinstance(ex, ValueError) or isinstance(ex, PromptError): return JSONResponse(content={"message": "Prompt execution arguments contains HTML tags that may cause security issues"}, status_code=422) diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index 10e1e8161..89fd050cd 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -105,9 +105,11 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginC Raises: NotImplementedError: needs to be implemented by sub class. """ - raise NotImplementedError(f"""'prompt_pre_fetch' not implemented for plugin {self._config.name} + raise NotImplementedError( + f"""'prompt_pre_fetch' not implemented for plugin {self._config.name} of plugin type {type(self)} - """) + """ + ) async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: """Plugin hook run after a prompt is rendered. @@ -119,9 +121,11 @@ async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: Plugi Raises: NotImplementedError: needs to be implemented by sub class. """ - raise NotImplementedError(f"""'prompt_post_fetch' not implemented for plugin {self._config.name} + raise NotImplementedError( + f"""'prompt_post_fetch' not implemented for plugin {self._config.name} of plugin type {type(self)} - """) + """ + ) async def shutdown(self) -> None: """Plugin cleanup code.""" diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index a74eb7283..689abb50b 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -83,6 +83,9 @@ async def execute( if result.modified_payload is not None: current_payload = result.modified_payload + if result.violation: + result.violation._plugin_name = pluginref.plugin.name + if not result.continue_processing: # Check execution mode if pluginref.plugin.mode == PluginMode.ENFORCE: diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index d174b80d7..c100244bb 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -14,7 +14,7 @@ from typing import Any, Optional # Third-Party -from pydantic import BaseModel +from pydantic import BaseModel, PrivateAttr class HookType(str, Enum): @@ -158,19 +158,21 @@ class PluginManifest(BaseModel): class PluginViolation(BaseModel): - """A plugin filter violation. + """A plugin violation, used to denote policy violations. Attributes: - plugin_name (str): The name of the plugin. - description (str): the violation in text. - violation_code (str): a violation code. - details: (dict[str, Any]) + reason (str): the reason for the violation. + description (str): a longer description of the violation. + code (str): a violation code. + details: (dict[str, Any]): additional violation details. + _plugin_name (str): the plugin name, private attribute set by the plugin manager. """ - plugin_name: str + reason: str description: str - violation_code: str + code: str details: dict[str, Any] + _plugin_name: PrivateAttr = "" class PluginSettings(BaseModel): diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index a991600e2..81a7c9b77 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -405,10 +405,11 @@ async def get_prompt( if not pre_result.continue_processing: # Plugin blocked the request if pre_result.violation: + plugin_name = pre_result.violation._plugin_name + violation_reason = pre_result.violation.reason violation_desc = pre_result.violation.description - plugin_name = pre_result.violation.plugin_name - violation_code = pre_result.violation.violation_code - raise PromptError(f"Pre prompting fetch blocked by plugin {plugin_name}: {violation_code} {violation_desc}") + violation_code = pre_result.violation.code + raise PromptError(f"Pre prompting fetch blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})") raise PromptError("Pre prompting fetch blocked by plugin") # Use modified payload if provided @@ -416,6 +417,8 @@ async def get_prompt( payload = pre_result.modified_payload name = payload.name arguments = payload.args + except PromptError: + raise except Exception as e: logger.error(f"Error in pre-prompt fetch plugin hook: {e}") # Only fail if configured to do so @@ -457,13 +460,16 @@ async def get_prompt( if not post_result.continue_processing: # Plugin blocked the request if post_result.violation: + plugin_name = post_result.violation._plugin_name + violation_reason = post_result.violation.reason violation_desc = post_result.violation.description - plugin_name = post_result.violation.plugin_name - violation_code = post_result.violation.violation_code - raise PromptError(f"Post prompting fetch blocked by plugin {plugin_name}: {violation_code} {violation_desc}") + violation_code = post_result.violation.code + raise PromptError(f"Post prompting fetch blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})") raise PromptError("Post prompting fetch blocked by plugin") # Use modified payload if provided return post_result.modified_payload.result if post_result.modified_payload else result + except PromptError: + raise except Exception as e: logger.error(f"Error in post-prompt fetch plugin hook: {e}") # Only fail if configured to do so diff --git a/plugins/config.yaml b/plugins/config.yaml index bbf704271..150b593ca 100644 --- a/plugins/config.yaml +++ b/plugins/config.yaml @@ -22,7 +22,25 @@ plugins: replace: crud - search: crud replace: yikes - + - name: "DenyListPlugin" + kind: "plugins.filter.deny.DenyListPlugin" + description: "A plugin that implements a deny list filter." + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch"] + tags: ["plugin", "filter", "denylist", "pre-post"] + mode: "enforce" # enforce | permissive | disabled + priority: 100 + conditions: + # Apply to specific tools/servers + - prompts: ["test_prompt"] + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - innovative + - groundbreaking + - revolutionary # Plugin directories to scan plugin_dirs: diff --git a/plugins/filter/deny.py b/plugins/filter/deny.py new file mode 100644 index 000000000..e890a17bb --- /dev/null +++ b/plugins/filter/deny.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +"""Simple example plugin for searching and replacing text. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Fred Araujo + +This module loads configurations for plugins. +""" +# Standard +import re + +# Third-Party +from pydantic import BaseModel + +# First-Party +from mcpgateway.plugins.framework.base import Plugin +from mcpgateway.plugins.framework.models import PluginConfig, PluginViolation +from mcpgateway.plugins.framework.plugin_types import PluginContext, PromptPrehookPayload, PromptPrehookResult + + +class DenyListConfig(BaseModel): + words: list[str] + + +class DenyListPlugin(Plugin): + """Example deny list plugin.""" + def __init__(self, config: PluginConfig): + super().__init__(config) + self._dconfig = DenyListConfig.model_validate(self._config.config) + self._deny_list = [] + for word in self._dconfig.words: + self._deny_list.append(word) + + async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: + """The plugin hook run before a prompt is retrieved and rendered. + + Args: + payload: The prompt payload to be analyzed. + context: contextual information about the hook call. + + Returns: + The result of the plugin's analysis, including whether the prompt can proceed. + """ + if payload.args: + for key in payload.args: + if any(word in payload.args[key] for word in self._deny_list): + violation = PluginViolation( + reason="Prompt not allowed", + description="A deny word was found in the prompt", + code="deny", + details={}, + ) + return PromptPrehookResult(modified_payload=payload, violation=violation, continue_processing=False) + return PromptPrehookResult(modified_payload=payload) + + diff --git a/plugins/filter/plugin-manifest.yaml b/plugins/filter/plugin-manifest.yaml new file mode 100644 index 000000000..d6c8ae801 --- /dev/null +++ b/plugins/filter/plugin-manifest.yaml @@ -0,0 +1,6 @@ +description: "Deny list plugin manifest." +author: "MCP Context Forge Team" +version: "0.1" +available_hooks: + - "prompt_pre_hook" +default_configs: diff --git a/plugins/regex/search_replace.py b/plugins/regex/search_replace.py index f04974396..94c4dd5b9 100644 --- a/plugins/regex/search_replace.py +++ b/plugins/regex/search_replace.py @@ -29,7 +29,7 @@ class SearchReplaceConfig(BaseModel): class SearchReplacePlugin(Plugin): - """Example search replace plugin""" + """Example search replace plugin.""" def __init__(self, config: PluginConfig): super().__init__(config) self._srconfig = SearchReplaceConfig.model_validate(self._config.config) From 33189592457b9a13f1d6fe9bae702302324568ba Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Mon, 4 Aug 2025 23:29:35 -0400 Subject: [PATCH 07/20] feat: add plugin violation error object Signed-off-by: Frederico Araujo --- mcpgateway/main.py | 4 +++- mcpgateway/plugins/__init__.py | 26 ++++++++++++++++++++++++++ mcpgateway/plugins/framework/models.py | 14 ++++++++++++++ mcpgateway/services/prompt_service.py | 15 +++++++-------- 4 files changed, 50 insertions(+), 9 deletions(-) create mode 100644 mcpgateway/plugins/__init__.py diff --git a/mcpgateway/main.py b/mcpgateway/main.py index f82921180..e767649fd 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -75,7 +75,7 @@ ResourceContent, Root, ) -from mcpgateway.plugins.framework.manager import PluginManager +from mcpgateway.plugins import PluginManager, PluginViolationError from mcpgateway.schemas import ( GatewayCreate, GatewayRead, @@ -1719,6 +1719,8 @@ async def get_prompt( logger.error(f"Could not retrieve prompt {name}: {ex}") if isinstance(ex, ValueError) or isinstance(ex, PromptError): return JSONResponse(content={"message": "Prompt execution arguments contains HTML tags that may cause security issues"}, status_code=422) + if isinstance(ex, PluginViolationError): + return JSONResponse(content={"message": "Prompt execution arguments contains HTML tags that may cause security issues", "details": ex.message}, status_code=422) @prompt_router.get("/{name}") diff --git a/mcpgateway/plugins/__init__.py b/mcpgateway/plugins/__init__.py new file mode 100644 index 000000000..0924ce5b2 --- /dev/null +++ b/mcpgateway/plugins/__init__.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +"""Services Package. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Fred Araujo + +Exposes core MCP Gateway plugin components: +- Context +- Manager +- Payloads +- Models +""" + +from mcpgateway.plugins.framework.manager import PluginManager +from mcpgateway.plugins.framework.models import PluginViolation, PluginViolationError +from mcpgateway.plugins.framework.plugin_types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload + +__all__ = [ + "GlobalContext", + "PluginManager", + "PluginViolation", + "PluginViolationError", + "PromptPosthookPayload", + "PromptPrehookPayload", +] diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index c100244bb..6c1158348 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -175,6 +175,20 @@ class PluginViolation(BaseModel): _plugin_name: PrivateAttr = "" +class PluginViolationError(Exception): + """A plugin violation error. + + Attributes: + violation (PluginViolation): the plugin violation. + message (str): the plugin violation reason. + """ + + def __init__(self, message: str, violation: PluginViolation | None = None): + self.message = message + self.violation = violation + super().__init__(self.message) + + class PluginSettings(BaseModel): """Global plugin settings. diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 81a7c9b77..374caf191 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -33,8 +33,7 @@ from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric, server_prompt_association from mcpgateway.models import Message, PromptResult, Role, TextContent -from mcpgateway.plugins.framework.manager import PluginManager -from mcpgateway.plugins.framework.plugin_types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins import GlobalContext, PluginManager, PluginViolationError, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate logger = logging.getLogger(__name__) @@ -409,15 +408,15 @@ async def get_prompt( violation_reason = pre_result.violation.reason violation_desc = pre_result.violation.description violation_code = pre_result.violation.code - raise PromptError(f"Pre prompting fetch blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})") - raise PromptError("Pre prompting fetch blocked by plugin") + raise PluginViolationError(f"Pre prompting fetch blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", pre_result.violation) + raise PluginViolationError("Pre prompting fetch blocked by plugin") # Use modified payload if provided if pre_result.modified_payload: payload = pre_result.modified_payload name = payload.name arguments = payload.args - except PromptError: + except PluginViolationError: raise except Exception as e: logger.error(f"Error in pre-prompt fetch plugin hook: {e}") @@ -464,11 +463,11 @@ async def get_prompt( violation_reason = post_result.violation.reason violation_desc = post_result.violation.description violation_code = post_result.violation.code - raise PromptError(f"Post prompting fetch blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})") - raise PromptError("Post prompting fetch blocked by plugin") + raise PluginViolationError(f"Post prompting fetch blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", post_result.violation) + raise PluginViolationError("Post prompting fetch blocked by plugin") # Use modified payload if provided return post_result.modified_payload.result if post_result.modified_payload else result - except PromptError: + except PluginViolationError: raise except Exception as e: logger.error(f"Error in post-prompt fetch plugin hook: {e}") From a1a140a7c5e73fffb32a1154ee9b90cb1b92d66f Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Tue, 5 Aug 2025 07:57:47 -0600 Subject: [PATCH 08/20] feat: added unit tests Signed-off-by: Teryl Taylor --- mcpgateway/plugins/framework/loader/config.py | 50 +-------- mcpgateway/plugins/framework/loader/plugin.py | 5 + mcpgateway/plugins/framework/manager.py | 9 +- mcpgateway/plugins/framework/models.py | 2 +- mcpgateway/plugins/framework/registry.py | 11 ++ .../configs/invalid_single_plugin.yaml | 37 +++++++ .../configs/valid_multiple_plugins.yaml | 57 ++++++++++ .../fixtures/configs/valid_no_plugin.yaml | 16 +++ .../fixtures/configs/valid_single_plugin.yaml | 37 +++++++ .../framework/loader/test_plugin_loader.py | 76 +++++++++++++ .../plugins/framework/test_manager.py | 102 ++++++++++++++++++ .../plugins/framework/test_registry.py | 27 +++++ .../plugins/framework/test_utils.py | 47 ++++++++ 13 files changed, 419 insertions(+), 57 deletions(-) create mode 100644 tests/unit/mcpgateway/plugins/fixtures/configs/invalid_single_plugin.yaml create mode 100644 tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml create mode 100644 tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml create mode 100644 tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml create mode 100644 tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py create mode 100644 tests/unit/mcpgateway/plugins/framework/test_manager.py create mode 100644 tests/unit/mcpgateway/plugins/framework/test_registry.py create mode 100644 tests/unit/mcpgateway/plugins/framework/test_utils.py diff --git a/mcpgateway/plugins/framework/loader/config.py b/mcpgateway/plugins/framework/loader/config.py index d314c2e3c..d78f3f481 100644 --- a/mcpgateway/plugins/framework/loader/config.py +++ b/mcpgateway/plugins/framework/loader/config.py @@ -16,7 +16,7 @@ import yaml # First-Party -from mcpgateway.plugins.framework.models import Config, PluginConfig, PluginManifest +from mcpgateway.plugins.framework.models import Config class ConfigLoader: @@ -42,51 +42,3 @@ def load_config(config: str, use_jinja: bool = True) -> Config: rendered_template = template config_data = yaml.safe_load(rendered_template) return Config(**config_data) - - @staticmethod - def dump_config(path: str, config: Config) -> None: - """Dump plugin configuration to a file. - - Args: - path: configuration file path - config: the plugin configuration path - """ - with open(os.path.normpath(path), "w", encoding="utf-8") as file: - yaml.safe_dump(config.model_dump(exclude_none=True), file) - - @staticmethod - def load_plugin_config(config: str) -> PluginConfig: - """Load a plugin configuration from a file path. - - This function autoescapes curly brackets in the 'instruction' - and 'examples' keys under the config attribute. - - Args: - config: the plugin configuration path - - Returns: - The plugin configuration object - """ - with open(os.path.normpath(config), "r", encoding="utf8") as file: - template = file.read() - jinja_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True) - rendered_template = jinja_env.from_string(template).render(env=os.environ) - config_data = yaml.safe_load(rendered_template) - return PluginConfig(**config_data) - - @staticmethod - def load_plugin_manifest(manifest: str) -> PluginManifest: - """Load a plugin manifest from a file path. - - Args: - manifest: the plugin manifest path - - Returns: - The plugin manifest object - """ - with open(os.path.normpath(manifest), "r", encoding="utf8") as file: - template = file.read() - jinja_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True) - rendered_template = jinja_env.from_string(template).render(env=os.environ) - config_data = yaml.safe_load(rendered_template) - return PluginManifest(**config_data) diff --git a/mcpgateway/plugins/framework/loader/plugin.py b/mcpgateway/plugins/framework/loader/plugin.py index 2be7e075a..2df9b0a84 100644 --- a/mcpgateway/plugins/framework/loader/plugin.py +++ b/mcpgateway/plugins/framework/loader/plugin.py @@ -73,3 +73,8 @@ async def load_and_instantiate_plugin(self, config: PluginConfig) -> Plugin | No if plugin_type: return plugin_type(config) return None + + async def shutdown(self) -> None: + """Shutdown and cleanup plugin loader.""" + if self._plugin_types: + self._plugin_types.clear() diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index 689abb50b..5ed5cddec 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -181,7 +181,7 @@ async def initialize(self) -> None: if self._initialized: return - plugins = self._config.plugins if self._config else [] + plugins = self._config.plugins if self._config and self._config.plugins else [] for plugin_config in plugins: if plugin_config.mode != PluginMode.DISABLED: @@ -195,12 +195,7 @@ async def initialize(self) -> None: async def shutdown(self) -> None: """Shutdown all plugins.""" - for plugin_ref in self._registry.get_all_plugins(): - try: - await plugin_ref.plugin.shutdown() - except Exception as e: - logger.error(f"Error shutting down plugin {plugin_ref.plugin.name}: {e}") - + await self._registry.shutdown() self._initialized = False async def prompt_pre_fetch( diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index 6c1158348..2c6baa643 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -216,6 +216,6 @@ class Config(BaseModel): plugin_settings: global settings for plugins. """ - plugins: list[PluginConfig] = [] + plugins: Optional[list[PluginConfig]] = [] plugin_dirs: list[str] = [] plugin_settings: PluginSettings diff --git a/mcpgateway/plugins/framework/registry.py b/mcpgateway/plugins/framework/registry.py index 7c2a5dce0..d9169d70a 100644 --- a/mcpgateway/plugins/framework/registry.py +++ b/mcpgateway/plugins/framework/registry.py @@ -114,3 +114,14 @@ def plugin_count(self) -> int: The number of plugins registered. """ return len(self._plugins) + + async def shutdown(self) -> None: + """Shutdown all plugins.""" + for plugin_ref in self._plugins.values(): + try: + await plugin_ref.plugin.shutdown() + except Exception as e: + logger.error(f"Error shutting down plugin {plugin_ref.plugin.name}: {e}") + self._plugins.clear() + self._hooks.clear() + self._priority_cache.clear() diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/invalid_single_plugin.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/invalid_single_plugin.yaml new file mode 100644 index 000000000..fdcfddfde --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/invalid_single_plugin.yaml @@ -0,0 +1,37 @@ +plugins: + # Self-contained Search Replace Plugin + - name: "FakePlugin" + kind: "some.fake.nonexistentPlugin" + description: "A plugin for finding and replacing words." + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] + tags: ["plugin", "transformer", "regex", "search-and-replace", "pre-post"] + mode: "enforce" # enforce | permissive | disabled + priority: 150 + conditions: + # Apply to specific tools/servers + - prompts: ["test_prompt"] + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - search: crap + replace: crud + - search: crud + replace: yikes + + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: false + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml new file mode 100644 index 000000000..7258c7cb4 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml @@ -0,0 +1,57 @@ +plugins: + - name: "SynonymsPlugin" + kind: "plugins.regex.search_replace.SearchReplacePlugin" + description: "A plugin for finding and replacing synonyms." + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] + tags: ["plugin", "transformer", "regex", "search-and-replace", "pre-post"] + mode: "enforce" # enforce | permissive | disabled + priority: 149 + conditions: + # Apply to specific tools/servers + - prompts: ["test_prompt"] + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - search: happy + replace: gleeful + - search: sad + replace: sullen + # Self-contained Search Replace Plugin + - name: "ReplaceBadWordsPlugin" + kind: "plugins.regex.search_replace.SearchReplacePlugin" + description: "A plugin for finding and replacing words." + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] + tags: ["plugin", "transformer", "regex", "search-and-replace", "pre-post"] + mode: "enforce" # enforce | permissive | disabled + priority: 150 + conditions: + # Apply to specific tools/servers + - prompts: ["test_prompt"] + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - search: crap + replace: crud + - search: crud + replace: yikes + + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: false + enable_plugin_api: true + plugin_health_check_interval: 60 \ No newline at end of file diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml new file mode 100644 index 000000000..a56fe3f75 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml @@ -0,0 +1,16 @@ +plugins: + # Self-contained Search Replace Plugin + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: false + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml new file mode 100644 index 000000000..5646b7ac5 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml @@ -0,0 +1,37 @@ +plugins: + # Self-contained Search Replace Plugin + - name: "ReplaceBadWordsPlugin" + kind: "plugins.regex.search_replace.SearchReplacePlugin" + description: "A plugin for finding and replacing words." + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] + tags: ["plugin", "transformer", "regex", "search-and-replace", "pre-post"] + mode: "enforce" # enforce | permissive | disabled + priority: 150 + conditions: + # Apply to specific tools/servers + - prompts: ["test_prompt"] + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - search: crap + replace: crud + - search: crud + replace: yikes + + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: false + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py new file mode 100644 index 000000000..53a0313b2 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py @@ -0,0 +1,76 @@ + +import pytest + +from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.plugins.framework.loader.config import ConfigLoader +from mcpgateway.plugins.framework.loader.plugin import PluginLoader +from mcpgateway.plugins.framework.models import PluginMode +from mcpgateway.plugins.framework.plugin_types import GlobalContext, PluginContext, PromptPosthookPayload, PromptPrehookPayload +from plugins.regex.search_replace import SearchReplaceConfig, SearchReplacePlugin + + +def test_config_loader_load(): + """pytest for testing the config loader.""" + config = ConfigLoader.load_config(config="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + assert config + assert len(config.plugins) == 1 + assert config.plugins[0].name == "ReplaceBadWordsPlugin" + assert config.plugins[0].kind == "plugins.regex.search_replace.SearchReplacePlugin" + assert config.plugins[0].description == "A plugin for finding and replacing words." + assert config.plugins[0].version == "0.1" + assert config.plugins[0].author == "MCP Context Forge Team" + assert config.plugins[0].hooks[0] == "prompt_pre_fetch" + assert config.plugins[0].hooks[1] == "prompt_post_fetch" + assert config.plugins[0].config + srconfig = SearchReplaceConfig.model_validate(config.plugins[0].config) + assert len(srconfig.words) == 2 + assert srconfig.words[0].search == "crap" + assert srconfig.words[0].replace == "crud" + +@pytest.mark.asyncio +async def test_plugin_loader_load(): + """Load a plugin with the plugin loader.""" + config = ConfigLoader.load_config(config="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + loader = PluginLoader() + plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) + assert isinstance(plugin, SearchReplacePlugin) + assert plugin.name == "ReplaceBadWordsPlugin" + assert plugin.mode == PluginMode.ENFORCE + assert plugin.priority == 150 + assert "test_prompt" in plugin.conditions[0].prompts + assert plugin.hooks[0] == "prompt_pre_fetch" + assert plugin.hooks[1] == "prompt_post_fetch" + + context = PluginContext(GlobalContext(request_id="1", server_id="2")) + prompt = PromptPrehookPayload(name="test_prompt", args = {"user": "What a crapshow!"}) + result = await plugin.prompt_pre_fetch(prompt, context=context) + assert len(result.modified_payload.args) == 1 + assert result.modified_payload.args["user"] == "What a yikesshow!" + + message=Message(content=TextContent(type="text", text="What the crud?"), role=Role.USER) + prompt_result = PromptResult(messages=[message]) + + payload_result = PromptPosthookPayload(name="test_prompt", result=prompt_result) + + result = await plugin.prompt_post_fetch(payload_result, context) + assert len(result.modified_payload.result.messages) == 1 + assert result.modified_payload.result.messages[0].content.text == 'What the yikes?' + + await loader.shutdown() + +@pytest.mark.asyncio +async def test_plugin_loader_invalid_plugin_load(): + """Load an invalid plugin with the plugin loader.""" + config = ConfigLoader.load_config(config="./tests/unit/mcpgateway/plugins/fixtures/configs/invalid_single_plugin.yaml", use_jinja=False) + loader = PluginLoader() + with pytest.raises(ModuleNotFoundError): + plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) + + + + + + + + + diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager.py b/tests/unit/mcpgateway/plugins/framework/test_manager.py new file mode 100644 index 000000000..1e662e433 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/framework/test_manager.py @@ -0,0 +1,102 @@ +import pytest + +from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.plugins.framework.manager import PluginManager +from mcpgateway.plugins.framework.plugin_types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload +from plugins.regex.search_replace import SearchReplaceConfig + + +@pytest.mark.asyncio +async def test_manager_single_transformer_prompt_plugin(): + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + await manager.initialize() + assert manager.config.plugins[0].name == "ReplaceBadWordsPlugin" + assert manager.config.plugins[0].kind == "plugins.regex.search_replace.SearchReplacePlugin" + assert manager.config.plugins[0].description == "A plugin for finding and replacing words." + assert manager.config.plugins[0].version == "0.1" + assert manager.config.plugins[0].author == "MCP Context Forge Team" + assert manager.config.plugins[0].hooks[0] == "prompt_pre_fetch" + assert manager.config.plugins[0].hooks[1] == "prompt_post_fetch" + assert manager.config.plugins[0].config + srconfig = SearchReplaceConfig.model_validate(manager.config.plugins[0].config) + assert len(srconfig.words) == 2 + assert srconfig.words[0].search == "crap" + assert srconfig.words[0].replace == "crud" + prompt = PromptPrehookPayload(name="test_prompt", args = {"user": "What a crapshow!"}) + global_context = GlobalContext(request_id="1", server_id="2") + result, contexts = await manager.prompt_pre_fetch(prompt, global_context=global_context) + assert len(result.modified_payload.args) == 1 + assert result.modified_payload.args["user"] == "What a yikesshow!" + + message=Message(content=TextContent(type="text", text=result.modified_payload.args["user"]), role=Role.USER) + + prompt_result = PromptResult(messages=[message]) + + payload_result = PromptPosthookPayload(name="test_prompt", result=prompt_result) + + result, _ = await manager.prompt_post_fetch(payload_result, global_context=global_context, local_contexts=contexts) + assert len(result.modified_payload.result.messages) == 1 + assert result.modified_payload.result.messages[0].content.text == 'What a yikesshow!' + await manager.shutdown() + +@pytest.mark.asyncio +async def test_manager_multiple_transformer_preprompt_plugin(): + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml") + await manager.initialize() + assert manager.initialized + assert manager.config.plugins[0].name == "SynonymsPlugin" + assert manager.config.plugins[0].kind == "plugins.regex.search_replace.SearchReplacePlugin" + assert manager.config.plugins[0].description == "A plugin for finding and replacing synonyms." + assert manager.config.plugins[0].version == "0.1" + assert manager.config.plugins[0].author == "MCP Context Forge Team" + assert manager.config.plugins[0].hooks[0] == "prompt_pre_fetch" + assert manager.config.plugins[0].hooks[1] == "prompt_post_fetch" + assert manager.config.plugins[0].config + srconfig = SearchReplaceConfig.model_validate(manager.config.plugins[0].config) + assert len(srconfig.words) == 2 + assert srconfig.words[0].search == "happy" + assert srconfig.words[0].replace == "gleeful" + assert manager.config.plugins[1].name == "ReplaceBadWordsPlugin" + assert manager.config.plugins[1].kind == "plugins.regex.search_replace.SearchReplacePlugin" + assert manager.config.plugins[1].description == "A plugin for finding and replacing words." + assert manager.config.plugins[1].version == "0.1" + assert manager.config.plugins[1].author == "MCP Context Forge Team" + assert manager.config.plugins[1].hooks[0] == "prompt_pre_fetch" + assert manager.config.plugins[1].hooks[1] == "prompt_post_fetch" + assert manager.config.plugins[1].config + srconfig = SearchReplaceConfig.model_validate(manager.config.plugins[1].config) + assert srconfig.words[0].search == "crap" + assert srconfig.words[0].replace == "crud" + assert manager.plugin_count == 2 + + prompt = PromptPrehookPayload(name="test_prompt", args = {"user": "It's always happy at the crapshow."}) + global_context = GlobalContext(request_id="1", server_id="2") + result, contexts = await manager.prompt_pre_fetch(prompt, global_context=global_context) + assert len(result.modified_payload.args) == 1 + assert result.modified_payload.args["user"] == "It's always gleeful at the yikesshow." + + message=Message(content=TextContent(type="text", text="It's sad at the crud bakery."), role=Role.USER) + + prompt_result = PromptResult(messages=[message]) + + payload_result = PromptPosthookPayload(name="test_prompt", result=prompt_result) + + result, _ = await manager.prompt_post_fetch(payload_result, global_context=global_context, local_contexts=contexts) + assert len(result.modified_payload.result.messages) == 1 + assert result.modified_payload.result.messages[0].content.text == "It's sullen at the yikes bakery." + await manager.shutdown() + +@pytest.mark.asyncio +async def test_manager_no_plugins(): + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") + await manager.initialize() + assert manager.initialized + prompt = PromptPrehookPayload(name="test_prompt", args = {"user": "It's always happy at the crapshow."}) + global_context = GlobalContext(request_id="1", server_id="2") + result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + assert result.continue_processing + assert not result.modified_payload + + + + diff --git a/tests/unit/mcpgateway/plugins/framework/test_registry.py b/tests/unit/mcpgateway/plugins/framework/test_registry.py new file mode 100644 index 000000000..abe2c4352 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/framework/test_registry.py @@ -0,0 +1,27 @@ +import pytest + +from mcpgateway.plugins.framework.loader.config import ConfigLoader +from mcpgateway.plugins.framework.loader.plugin import PluginLoader +from mcpgateway.plugins.framework.registry import PluginInstanceRegistry + +@pytest.mark.asyncio +async def test_registry_register(): + """Load a plugin with the plugin loader.""" + config = ConfigLoader.load_config(config="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + loader = PluginLoader() + plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) + registry = PluginInstanceRegistry() + registry.register(plugin) + + all_plugins = registry.get_all_plugins() + assert len(all_plugins) == 1 + assert registry.get_plugin("ReplaceBadWordsPlugin") + assert registry.get_plugin("SomeNonExistentPlugin") is None + + registry.unregister("ReplaceBadWordsPlugin") + assert registry.plugin_count == 0 + + registry.unregister("SomePluginThatDoesntExist") + + all_plugins = registry.get_all_plugins() + assert len(all_plugins) == 0 diff --git a/tests/unit/mcpgateway/plugins/framework/test_utils.py b/tests/unit/mcpgateway/plugins/framework/test_utils.py new file mode 100644 index 000000000..acd302dd4 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/framework/test_utils.py @@ -0,0 +1,47 @@ + +from mcpgateway.plugins.framework.utils import pre_prompt_matches, matches, post_prompt_matches +from mcpgateway.plugins.framework.models import PluginCondition +from mcpgateway.plugins.framework.plugin_types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload + + + +def test_server_ids(): + condition1 = PluginCondition(server_ids={"1", "2"}) + context1 = GlobalContext(server_id="1", tenant_id="4", request_id="5") + + payload1 = PromptPrehookPayload(name="test_prompt", args={}) + + assert matches(condition=condition1, context=context1) + assert pre_prompt_matches(payload1, [condition1], context1) + + context2 = GlobalContext(server_id="3", tenant_id="6", request_id="1") + assert not matches(condition=condition1, context=context2) + assert not pre_prompt_matches(payload1, conditions=[condition1], context=context2) + + condition2 = PluginCondition(server_ids={"1"}, tenant_ids={"4"}) + + context2 = GlobalContext(server_id="1", tenant_id="4", request_id="1") + + assert matches(condition2, context2) + assert pre_prompt_matches(payload1, conditions=[condition2], context=context2) + + context3 = GlobalContext(server_id="1", tenant_id="5", request_id="1") + + assert not matches(condition2, context3) + assert not pre_prompt_matches(payload1, conditions=[condition2], context=context3) + + condition4 = PluginCondition(user_patterns=["blah", "barker", "bobby"]) + context4 = GlobalContext(user="blah", request_id="1") + + assert matches(condition4, context4) + assert pre_prompt_matches(payload1, conditions=[condition4], context=context4) + + context5 = GlobalContext(user="barney", request_id="1") + assert not matches(condition4, context5) + assert not pre_prompt_matches(payload1, conditions=[condition4], context=context5) + + condition5 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt"}) + + assert pre_prompt_matches(payload1, [condition5], context1) + condition6 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt2"}) + assert not pre_prompt_matches(payload1, [condition6], context1) \ No newline at end of file From a5af9eb22b2605e544d26ee5798fc97770a8ac9e Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Tue, 5 Aug 2025 08:34:39 -0600 Subject: [PATCH 09/20] fix(license): add licensing to unit test files and run lint. Signed-off-by: Teryl Taylor --- plugins/filter/deny.py | 2 -- .../configs/valid_multiple_plugins.yaml | 2 +- .../framework/loader/test_plugin_loader.py | 18 +++++++++--------- .../plugins/framework/test_manager.py | 15 ++++++++++----- .../plugins/framework/test_registry.py | 11 ++++++++++- .../mcpgateway/plugins/framework/test_utils.py | 10 +++++++++- 6 files changed, 39 insertions(+), 19 deletions(-) diff --git a/plugins/filter/deny.py b/plugins/filter/deny.py index e890a17bb..034d83c03 100644 --- a/plugins/filter/deny.py +++ b/plugins/filter/deny.py @@ -53,5 +53,3 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginC ) return PromptPrehookResult(modified_payload=payload, violation=violation, continue_processing=False) return PromptPrehookResult(modified_payload=payload) - - diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml index 7258c7cb4..6a88124c5 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml @@ -54,4 +54,4 @@ plugin_settings: plugin_timeout: 30 fail_on_plugin_error: false enable_plugin_api: true - plugin_health_check_interval: 60 \ No newline at end of file + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py index 53a0313b2..475c41d2f 100644 --- a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py +++ b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py @@ -1,3 +1,12 @@ +# -*- coding: utf-8 -*- +""" + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for config and plugin loaders. +""" import pytest @@ -65,12 +74,3 @@ async def test_plugin_loader_invalid_plugin_load(): loader = PluginLoader() with pytest.raises(ModuleNotFoundError): plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) - - - - - - - - - diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager.py b/tests/unit/mcpgateway/plugins/framework/test_manager.py index 1e662e433..6e4547919 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager.py @@ -1,3 +1,12 @@ +# -*- coding: utf-8 -*- +""" + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for plugin manager. +""" import pytest from mcpgateway.models import Message, PromptResult, Role, TextContent @@ -68,7 +77,7 @@ async def test_manager_multiple_transformer_preprompt_plugin(): assert srconfig.words[0].search == "crap" assert srconfig.words[0].replace == "crud" assert manager.plugin_count == 2 - + prompt = PromptPrehookPayload(name="test_prompt", args = {"user": "It's always happy at the crapshow."}) global_context = GlobalContext(request_id="1", server_id="2") result, contexts = await manager.prompt_pre_fetch(prompt, global_context=global_context) @@ -96,7 +105,3 @@ async def test_manager_no_plugins(): result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) assert result.continue_processing assert not result.modified_payload - - - - diff --git a/tests/unit/mcpgateway/plugins/framework/test_registry.py b/tests/unit/mcpgateway/plugins/framework/test_registry.py index abe2c4352..cebe8d330 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_registry.py +++ b/tests/unit/mcpgateway/plugins/framework/test_registry.py @@ -1,3 +1,12 @@ +# -*- coding: utf-8 -*- +""" + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for plugin registry. +""" import pytest from mcpgateway.plugins.framework.loader.config import ConfigLoader @@ -22,6 +31,6 @@ async def test_registry_register(): assert registry.plugin_count == 0 registry.unregister("SomePluginThatDoesntExist") - + all_plugins = registry.get_all_plugins() assert len(all_plugins) == 0 diff --git a/tests/unit/mcpgateway/plugins/framework/test_utils.py b/tests/unit/mcpgateway/plugins/framework/test_utils.py index acd302dd4..830b1fe0a 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_utils.py +++ b/tests/unit/mcpgateway/plugins/framework/test_utils.py @@ -1,4 +1,12 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for utilities. +""" from mcpgateway.plugins.framework.utils import pre_prompt_matches, matches, post_prompt_matches from mcpgateway.plugins.framework.models import PluginCondition from mcpgateway.plugins.framework.plugin_types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload @@ -44,4 +52,4 @@ def test_server_ids(): assert pre_prompt_matches(payload1, [condition5], context1) condition6 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt2"}) - assert not pre_prompt_matches(payload1, [condition6], context1) \ No newline at end of file + assert not pre_prompt_matches(payload1, [condition6], context1) From 2dcb75b9600d31e34e7a2a842c487cf6caf20de0 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Tue, 5 Aug 2025 10:56:53 -0600 Subject: [PATCH 10/20] fix(plugin): linting issues. Signed-off-by: Teryl Taylor --- mcpgateway/config.py | 4 +- mcpgateway/plugins/__init__.py | 4 +- mcpgateway/plugins/framework/manager.py | 2 +- mcpgateway/plugins/framework/models.py | 39 ++++++++++++-------- mcpgateway/plugins/framework/plugin_types.py | 20 ++++++++++ mcpgateway/services/prompt_service.py | 5 ++- mcpgateway/validators.py | 4 +- 7 files changed, 52 insertions(+), 26 deletions(-) diff --git a/mcpgateway/config.py b/mcpgateway/config.py index db80a501e..28d4fd8df 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -501,9 +501,7 @@ def validate_database(self) -> None: db_dir.mkdir(parents=True) # Validation patterns for safe display (configurable) - validation_dangerous_html_pattern: str = ( - r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" - ) + validation_dangerous_html_pattern: str = r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" validation_dangerous_js_pattern: str = r"(?i)(?:^|\s|[\"'`<>=])(javascript:|vbscript:|data:\s*[^,]*[;\s]*(javascript|vbscript)|\bon[a-z]+\s*=|<\s*script\b)" diff --git a/mcpgateway/plugins/__init__.py b/mcpgateway/plugins/__init__.py index 0924ce5b2..d93e84e99 100644 --- a/mcpgateway/plugins/__init__.py +++ b/mcpgateway/plugins/__init__.py @@ -13,8 +13,8 @@ """ from mcpgateway.plugins.framework.manager import PluginManager -from mcpgateway.plugins.framework.models import PluginViolation, PluginViolationError -from mcpgateway.plugins.framework.plugin_types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework.models import PluginViolation +from mcpgateway.plugins.framework.plugin_types import GlobalContext, PluginViolationError, PromptPosthookPayload, PromptPrehookPayload __all__ = [ "GlobalContext", diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index 5ed5cddec..0c0f34b96 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -84,7 +84,7 @@ async def execute( current_payload = result.modified_payload if result.violation: - result.violation._plugin_name = pluginref.plugin.name + result.violation.plugin_name = pluginref.plugin.name if not result.continue_processing: # Check execution mode diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index 2c6baa643..cec4235ed 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -172,21 +172,30 @@ class PluginViolation(BaseModel): description: str code: str details: dict[str, Any] - _plugin_name: PrivateAttr = "" - - -class PluginViolationError(Exception): - """A plugin violation error. - - Attributes: - violation (PluginViolation): the plugin violation. - message (str): the plugin violation reason. - """ - - def __init__(self, message: str, violation: PluginViolation | None = None): - self.message = message - self.violation = violation - super().__init__(self.message) + _plugin_name: str = PrivateAttr(default="") + + @property + def plugin_name(self) -> str: + """Getter for the plugin name attribute. + + Returns: + The plugin name associated with the violation. + """ + return self._plugin_name + + @plugin_name.setter + def plugin_name(self, name: str) -> None: + """Setter for the plugin_name attribute. + + Args: + name: the plugin name. + + Raises: + ValueError: if name is empty or not a string. + """ + if not isinstance(name, str) or not name.strip(): + raise ValueError("Name must be a non-empty string.") + self._plugin_name = name class PluginSettings(BaseModel): diff --git a/mcpgateway/plugins/framework/plugin_types.py b/mcpgateway/plugins/framework/plugin_types.py index 9ee70ebf8..d511997f2 100644 --- a/mcpgateway/plugins/framework/plugin_types.py +++ b/mcpgateway/plugins/framework/plugin_types.py @@ -134,3 +134,23 @@ async def cleanup(self) -> None: PluginContextTable = dict[str, PluginContext] + + +class PluginViolationError(Exception): + """A plugin violation error. + + Attributes: + violation (PluginViolation): the plugin violation. + message (str): the plugin violation reason. + """ + + def __init__(self, message: str, violation: PluginViolation | None = None): + """Initialize a plugin violation error. + + Args: + message: the reason for the violation error. + violation: the plugin violation object details. + """ + self.message = message + self.violation = violation + super().__init__(self.message) diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 374caf191..75f1f5354 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -378,6 +378,7 @@ async def get_prompt( Prompt result with rendered messages Raises: + PluginViolationError: If prompt violates a plugin policy PromptNotFoundError: If prompt not found PromptError: For other prompt errors @@ -404,7 +405,7 @@ async def get_prompt( if not pre_result.continue_processing: # Plugin blocked the request if pre_result.violation: - plugin_name = pre_result.violation._plugin_name + plugin_name = pre_result.violation.plugin_name violation_reason = pre_result.violation.reason violation_desc = pre_result.violation.description violation_code = pre_result.violation.code @@ -459,7 +460,7 @@ async def get_prompt( if not post_result.continue_processing: # Plugin blocked the request if post_result.violation: - plugin_name = post_result.violation._plugin_name + plugin_name = post_result.violation.plugin_name violation_reason = post_result.violation.reason violation_desc = post_result.violation.description violation_code = post_result.violation.code diff --git a/mcpgateway/validators.py b/mcpgateway/validators.py index a242a54b6..c7be0987a 100644 --- a/mcpgateway/validators.py +++ b/mcpgateway/validators.py @@ -52,9 +52,7 @@ class SecurityValidator: """Configurable validation with MCP-compliant limits""" # Configurable patterns (from settings) - DANGEROUS_HTML_PATTERN = ( - settings.validation_dangerous_html_pattern - ) # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' + DANGEROUS_HTML_PATTERN = settings.validation_dangerous_html_pattern # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' DANGEROUS_JS_PATTERN = settings.validation_dangerous_js_pattern # Default: javascript:|vbscript:|on\w+\s*=|data:.*script ALLOWED_URL_SCHEMES = settings.validation_allowed_url_schemes # Default: ["http://", "https://", "ws://", "wss://"] From 18d2e4f9be71673a5bb195ce9decf5a65fa40f35 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Tue, 5 Aug 2025 11:16:02 -0600 Subject: [PATCH 11/20] fix(plugins): added yaml dependency for plugins to pyproject.toml Signed-off-by: Teryl Taylor --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 1333a091a..29c605340 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ dependencies = [ "pydantic>=2.11.7", "pydantic-settings>=2.10.1", "pyjwt>=2.10.1", + "PyYAML>=6.0.2", "sqlalchemy>=2.0.42", "sse-starlette>=3.0.2", "starlette>=0.47.2", From 5f19c922291a4c2a718ea5b18437f29f943400a1 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Tue, 5 Aug 2025 15:03:02 -0600 Subject: [PATCH 12/20] test(plugin): added tests for filter plugins Signed-off-by: Teryl Taylor --- .../plugins/framework/test_manager.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager.py b/tests/unit/mcpgateway/plugins/framework/test_manager.py index 6e4547919..fa713b9ec 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager.py @@ -105,3 +105,30 @@ async def test_manager_no_plugins(): result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) assert result.continue_processing assert not result.modified_payload + await manager.shutdown() + +@pytest.mark.asyncio +async def test_manager_filter_plugins(): + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_filter_plugin.yaml") + await manager.initialize() + assert manager.initialized + prompt = PromptPrehookPayload(name="test_prompt", args = {"user": "innovative"}) + global_context = GlobalContext(request_id="1", server_id="2") + result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + assert not result.continue_processing + assert result.violation + await manager.shutdown() + +@pytest.mark.asyncio +async def test_manager_multi_filter_plugins(): + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + await manager.initialize() + assert manager.initialized + prompt = PromptPrehookPayload(name="test_prompt", args = {"user": "innovative crapshow."}) + global_context = GlobalContext(request_id="1", server_id="2") + result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + assert not result.continue_processing + assert result.violation + await manager.shutdown() + + From 2438109a3f3e4007b73539e951f9535f5bc00b8a Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Tue, 5 Aug 2025 15:13:25 -0600 Subject: [PATCH 13/20] test(plugin): add missing config files for plugin tests Signed-off-by: Teryl Taylor --- .../valid_multiple_plugins_filter.yaml | 57 +++++++++++++++++++ .../configs/valid_single_filter_plugin.yaml | 36 ++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml create mode 100644 tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_filter_plugin.yaml diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml new file mode 100644 index 000000000..bbc0fc6ad --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml @@ -0,0 +1,57 @@ +plugins: + # Self-contained Deny List Plugin + - name: "DenyListPlugin" + kind: "plugins.filter.deny.DenyListPlugin" + description: "A plugin that implements a deny list filter." + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch"] + tags: ["plugin", "filter", "denylist", "pre-post"] + mode: "enforce" # enforce | permissive | disabled + priority: 100 + conditions: + # Apply to specific tools/servers + - prompts: ["test_prompt"] + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - innovative + - groundbreaking + - revolutionary + # Self-contained Search Replace Plugin + - name: "ReplaceBadWordsPlugin" + kind: "plugins.regex.search_replace.SearchReplacePlugin" + description: "A plugin for finding and replacing words." + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] + tags: ["plugin", "transformer", "regex", "search-and-replace", "pre-post"] + mode: "enforce" # enforce | permissive | disabled + priority: 150 + conditions: + # Apply to specific tools/servers + - prompts: ["test_prompt"] + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - search: crap + replace: crud + - search: crud + replace: yikes + + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: false + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_filter_plugin.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_filter_plugin.yaml new file mode 100644 index 000000000..ba63e818b --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_filter_plugin.yaml @@ -0,0 +1,36 @@ +plugins: + # Self-contained Deny List Plugin + - name: "DenyListPlugin" + kind: "plugins.filter.deny.DenyListPlugin" + description: "A plugin that implements a deny list filter." + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch"] + tags: ["plugin", "filter", "denylist", "pre-post"] + mode: "enforce" # enforce | permissive | disabled + priority: 100 + conditions: + # Apply to specific tools/servers + - prompts: ["test_prompt"] + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - innovative + - groundbreaking + - revolutionary + + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: false + enable_plugin_api: true + plugin_health_check_interval: 60 From 3e3eb7b711b6688dfb2e9c4cb51ac90cde64a200 Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Tue, 5 Aug 2025 23:18:02 +0100 Subject: [PATCH 14/20] Add PII filter plugin Signed-off-by: Mihai Criveti --- mcpgateway/config.py | 4 +- mcpgateway/validators.py | 4 +- plugins/config.yaml | 34 + plugins/pii_filter/README.md | 352 ++++++++++ plugins/pii_filter/__init__.py | 0 plugins/pii_filter/pii_filter.py | 646 ++++++++++++++++++ plugins/pii_filter/plugin-manifest.yaml | 11 + .../plugins/framework/test_manager.py | 2 - .../plugins/pii_filter/test_pii_filter.py | 503 ++++++++++++++ 9 files changed, 1552 insertions(+), 4 deletions(-) create mode 100644 plugins/pii_filter/README.md create mode 100644 plugins/pii_filter/__init__.py create mode 100644 plugins/pii_filter/pii_filter.py create mode 100644 plugins/pii_filter/plugin-manifest.yaml create mode 100644 tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 28d4fd8df..db80a501e 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -501,7 +501,9 @@ def validate_database(self) -> None: db_dir.mkdir(parents=True) # Validation patterns for safe display (configurable) - validation_dangerous_html_pattern: str = r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" + validation_dangerous_html_pattern: str = ( + r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" + ) validation_dangerous_js_pattern: str = r"(?i)(?:^|\s|[\"'`<>=])(javascript:|vbscript:|data:\s*[^,]*[;\s]*(javascript|vbscript)|\bon[a-z]+\s*=|<\s*script\b)" diff --git a/mcpgateway/validators.py b/mcpgateway/validators.py index c7be0987a..a242a54b6 100644 --- a/mcpgateway/validators.py +++ b/mcpgateway/validators.py @@ -52,7 +52,9 @@ class SecurityValidator: """Configurable validation with MCP-compliant limits""" # Configurable patterns (from settings) - DANGEROUS_HTML_PATTERN = settings.validation_dangerous_html_pattern # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' + DANGEROUS_HTML_PATTERN = ( + settings.validation_dangerous_html_pattern + ) # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' DANGEROUS_JS_PATTERN = settings.validation_dangerous_js_pattern # Default: javascript:|vbscript:|on\w+\s*=|data:.*script ALLOWED_URL_SCHEMES = settings.validation_allowed_url_schemes # Default: ["http://", "https://", "ws://", "wss://"] diff --git a/plugins/config.yaml b/plugins/config.yaml index 150b593ca..dd5f81391 100644 --- a/plugins/config.yaml +++ b/plugins/config.yaml @@ -1,6 +1,40 @@ # plugins/config.yaml - Main plugin configuration file plugins: + # PII Filter Plugin - Run first with highest priority for security + - name: "PIIFilterPlugin" + kind: "plugins.pii_filter.pii_filter.PIIFilterPlugin" + description: "Detects and masks Personally Identifiable Information" + version: "1.0" + author: "Mihai Criveti" + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] + tags: ["security", "pii", "compliance", "filter", "gdpr", "hipaa"] + mode: "permissive" # enforce | permissive | disabled + priority: 50 # Lower number = higher priority (runs first) + conditions: + - prompts: [] # Empty list = apply to all prompts + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + # PII Detection Settings + detect_ssn: true + detect_credit_card: true + detect_email: true + detect_phone: true + detect_ip_address: false # Disabled for development + detect_aws_keys: true + detect_api_keys: true + # Masking Settings + default_mask_strategy: "partial" # redact | partial | hash | tokenize | remove + redaction_text: "[PII_REDACTED]" + # Behavior Settings + block_on_detection: false # Set to true for strict compliance + log_detections: true + include_detection_details: true + # Whitelist common test values + whitelist_patterns: + - "test@example.com" + - "555-555-5555" # Self-contained Search Replace Plugin - name: "ReplaceBadWordsPlugin" kind: "plugins.regex.search_replace.SearchReplacePlugin" diff --git a/plugins/pii_filter/README.md b/plugins/pii_filter/README.md new file mode 100644 index 000000000..c9c80119a --- /dev/null +++ b/plugins/pii_filter/README.md @@ -0,0 +1,352 @@ +# PII Filter Plugin for MCP Gateway + +> Author: Mihai Criveti + +A plugin for detecting and masking Personally Identifiable Information (PII) in MCP Gateway prompts and responses. + +## Features + +### PII Detection Types +- **Social Security Numbers (SSN)** - US format (123-45-6789 or 123456789) +- **Credit Card Numbers** - Major card formats with various separators +- **Email Addresses** - Standard email format validation +- **Phone Numbers** - US and international formats +- **IP Addresses** - IPv4 and IPv6 +- **Dates of Birth** - Various date formats with context +- **Passport Numbers** - International passport formats +- **Driver's License Numbers** - US state formats +- **Bank Account Numbers** - Including IBAN +- **Medical Record Numbers** - MRN formats +- **AWS Access Keys** - AKIA prefixed keys and secrets +- **API Keys** - Generic API key patterns +- **Custom Patterns** - Define your own PII patterns + +### Masking Strategies +- **REDACT** - Complete replacement with `[REDACTED]` or custom text +- **PARTIAL** - Show partial info (e.g., `***-**-1234` for SSN, `j***e@example.com` for email) +- **HASH** - Replace with hash value for consistency +- **TOKENIZE** - Replace with unique token for reversibility +- **REMOVE** - Complete removal of PII + +### Operating Modes +- **ENFORCE** - Block or mask PII (based on configuration) +- **PERMISSIVE** - Log detections but don't block +- **DISABLED** - Turn off the plugin + +## Installation + +1. Copy .env.example .env +2. Enable plugins in `.env` +3. Add the plugin configuration to `plugins/config.yaml`: + +```yaml +plugins: + - name: "PIIFilterPlugin" + kind: "plugins.pii_filter.pii_filter.PIIFilterPlugin" + description: "Detects and masks Personally Identifiable Information" + version: "1.0" + author: "Security Team" + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] + tags: ["security", "pii", "compliance", "filter", "gdpr", "hipaa"] + mode: "enforce" # enforce | permissive | disabled + priority: 10 # Lower number = higher priority (runs first) + conditions: + - prompts: [] # Empty list = apply to all prompts + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + # PII Detection Settings + detect_ssn: true + detect_credit_card: true + detect_email: true + detect_phone: true + detect_ip_address: true + detect_aws_keys: true + detect_api_keys: true + # Masking Settings + default_mask_strategy: "partial" + redaction_text: "[PII_REDACTED]" + # Behavior Settings + block_on_detection: false + log_detections: true + include_detection_details: true + # Whitelist patterns + whitelist_patterns: + - "test@example.com" + - "555-555-5555" +``` + +## Configuration Examples + +### Development Environment (Permissive) +```yaml +config: + mode: "permissive" # Only log, don't block + detect_email: false # Allow emails in dev + detect_phone: false # Allow phones in dev + default_mask_strategy: "partial" # Show partial info for debugging + block_on_detection: false +``` + +### Production Environment (Strict Compliance) +```yaml +config: + mode: "enforce" + block_on_detection: true # Block any request with PII + default_mask_strategy: "redact" # Complete redaction + log_detections: true + detect_ssn: true + detect_credit_card: true + detect_email: true + # ... enable all detection types +``` + +### API Keys Only +```yaml +config: + detect_ssn: false + detect_credit_card: false + detect_email: false + detect_phone: false + detect_aws_keys: true # Only detect API keys + detect_api_keys: true + block_on_detection: true # Always block if keys detected + default_mask_strategy: "redact" +``` + +## Testing + +### Run All Tests +```bash +# Run all PII filter tests +pytest tests/unit/mcpgateway/plugins/pii_filter/test_pii_filter.py -v + +# Run with coverage +pytest tests/unit/mcpgateway/plugins/pii_filter/test_pii_filter.py --cov=plugins.pii_filter --cov-report=term-missing +``` + +### Run Specific Test Classes +```bash +# Test only the detector functionality +pytest tests/unit/mcpgateway/plugins/pii_filter/test_pii_filter.py::TestPIIDetector -v + +# Test only the plugin integration +pytest tests/unit/mcpgateway/plugins/pii_filter/test_pii_filter.py::TestPIIFilterPlugin -v +``` + +### Run Individual Tests +```bash +# Test SSN detection +pytest tests/unit/mcpgateway/plugins/pii_filter/test_pii_filter.py::TestPIIDetector::test_ssn_detection -v + +# Test masking strategies +pytest tests/unit/mcpgateway/plugins/pii_filter/test_pii_filter.py::TestPIIDetector::test_masking_strategies -v + +# Test blocking mode +pytest tests/unit/mcpgateway/plugins/pii_filter/test_pii_filter.py::TestPIIFilterPlugin::test_prompt_pre_fetch_blocking -v +``` + +### Manual Testing with the Gateway + +1. Enable the plugin in your `.env`: +```bash +PLUGINS_ENABLED=true +``` + +2. Start the gateway: +```bash +python -m mcpgateway.main +``` + +3. Test with curl: +```bash +# Test PII detection in prompt arguments +curl -X POST http://localhost:8000/prompts/test_prompt \ + -H "Content-Type: application/json" \ + -d '{ + "args": { + "user_input": "My SSN is 123-45-6789 and email is john@example.com" + } + }' + +# Response should have masked PII: +# "user_input": "My SSN is ***-**-6789 and email is j***n@example.com" +``` + +### Test Custom Patterns + +Add custom patterns in your config: +```yaml +config: + custom_patterns: + - type: "custom" + pattern: "\\bEMP\\d{6}\\b" + description: "Employee ID" + mask_strategy: "redact" + enabled: true +``` + +Test the custom pattern: +```python +from plugins.pii_filter.pii_filter import PIIFilterPlugin, PIIFilterConfig, PIIDetector + +config = PIIFilterConfig( + custom_patterns=[{ + "type": "custom", + "pattern": r"\bEMP\d{6}\b", + "description": "Employee ID", + "mask_strategy": "redact", + "enabled": True + }] +) +detector = PIIDetector(config) + +text = "Employee ID: EMP123456" +detections = detector.detect(text) +masked = detector.mask(text, detections) +print(masked) # Output: "Employee ID: [REDACTED]" +``` + +## Debugging + +### Enable Debug Logging +```python +import logging +logging.basicConfig(level=logging.DEBUG) + +# The plugin will log all PII detections +logger = logging.getLogger("plugins.pii_filter.pii_filter") +logger.setLevel(logging.DEBUG) +``` + +### Check Detection Results +```python +from plugins.pii_filter.pii_filter import PIIDetector, PIIFilterConfig + +config = PIIFilterConfig(detect_ssn=True, detect_email=True) +detector = PIIDetector(config) + +text = "SSN: 123-45-6789, Email: test@example.com" +detections = detector.detect(text) + +# Inspect what was detected +for pii_type, items in detections.items(): + print(f"Type: {pii_type}") + for item in items: + print(f" - Value: {item['value']}") + print(f" - Position: {item['start']}-{item['end']}") + print(f" - Strategy: {item['mask_strategy']}") +``` + +## Common Issues and Solutions + +### Issue: PII not being detected +**Solution**: Check that the specific detection type is enabled in config: +```yaml +config: + detect_ssn: true # Make sure this is true + detect_email: true +``` + +### Issue: False positives (detecting non-PII) +**Solution**: Use whitelist patterns: +```yaml +config: + whitelist_patterns: + - "test@example.com" + - "555-555-5555" + - "000-00-0000" +``` + +### Issue: Overlapping detections +**Solution**: The plugin automatically handles overlapping patterns by keeping only the first match. If you need different behavior, adjust pattern priorities or use custom patterns. + +### Issue: Plugin not running +**Solution**: Verify: +1. `PLUGINS_ENABLED=true` in `.env` +2. Plugin priority is set correctly (lower number = runs first) +3. Plugin mode is not set to "disabled" +4. Conditions match your prompts/servers + +## Performance Considerations + +- **Pattern Compilation**: Patterns are compiled once during initialization +- **Detection Speed**: O(n*m) where n = text length, m = number of patterns +- **Memory Usage**: Minimal - only stores compiled patterns and current detections +- **Caching**: No caching by default (stateless detection) + +## Security Best Practices + +1. **Production Settings**: + - Always use `mode: "enforce"` in production + - Enable `block_on_detection: true` for sensitive environments + - Use `default_mask_strategy: "redact"` for complete removal + +2. **Logging**: + - Enable `log_detections: true` for audit trails + - Monitor logs for PII detection patterns + - Never log the actual PII values + +3. **Testing**: + - Test with realistic data patterns + - Verify whitelist patterns don't expose real PII + - Regularly update patterns for new PII formats + + +## Sample Prompt + +Here's a prompt that trips the checks: + +```text +Personal Info: +SSN: 123-45-6789 or 987654321 +Email: john@example.com +Phone: (555) 123-4567 or +1-800-555-0199 +DOB: 01/15/1985 +``` + +## CURL Command to Test + +```bash +export MCPGATEWAY_BEARER_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token -u admin --secret my-test-key) + +# Then test with a prompt containing various PII +curl -X GET "http://localhost:4444/prompts/test_prompt" \ + -H "Authorization: Bearer $MCPGATEWAY_BEARER_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "arguments": { + "user_input": "My SSN is 123-45-6789 and email is john@example.com. Credit card: 4111-1111-1111-1111, phone (555) 123-4567. Server IP: 192.168.1.1, AWS Key: AKIAIOSFODNN7EXAMPLE" + } + }' +``` + +## Contributing + +To add new PII detection patterns: + +1. Add the pattern to `_compile_patterns()` method: +```python +if self.config.detect_my_pattern: + patterns.append(PIIPattern( + type=PIIType.MY_PATTERN, + pattern=r'your-regex-here', + description="Description", + mask_strategy=MaskingStrategy.REDACT + )) +``` + +2. Add configuration option to `PIIFilterConfig`: +```python +detect_my_pattern: bool = Field(default=True, description="Detect my pattern") +``` + +3. Add tests to verify detection and masking + +## License + +Apache-2.0 + +## Support + +For issues or questions, please open an issue in the MCP Gateway repository. diff --git a/plugins/pii_filter/__init__.py b/plugins/pii_filter/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/pii_filter/pii_filter.py b/plugins/pii_filter/pii_filter.py new file mode 100644 index 000000000..25d69653a --- /dev/null +++ b/plugins/pii_filter/pii_filter.py @@ -0,0 +1,646 @@ +# -*- coding: utf-8 -*- +"""PII Filter Plugin for MCP Gateway. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +This plugin detects and masks Personally Identifiable Information (PII) in prompts +and their responses, including SSNs, credit cards, emails, phone numbers, and more. +""" + +# Standard +import re +from enum import Enum +from typing import Optional, Pattern, Dict, List, Tuple +import logging + +# Third-Party +from pydantic import BaseModel, Field + +# First-Party +from mcpgateway.plugins.framework.base import Plugin +from mcpgateway.plugins.framework.models import PluginConfig, PluginViolation +from mcpgateway.plugins.framework.plugin_types import ( + PluginContext, + PromptPosthookPayload, + PromptPosthookResult, + PromptPrehookPayload, + PromptPrehookResult, +) + +logger = logging.getLogger(__name__) + + +class PIIType(str, Enum): + """Types of PII that can be detected.""" + + SSN = "ssn" + CREDIT_CARD = "credit_card" + EMAIL = "email" + PHONE = "phone" + IP_ADDRESS = "ip_address" + DATE_OF_BIRTH = "date_of_birth" + PASSPORT = "passport" + DRIVER_LICENSE = "driver_license" + BANK_ACCOUNT = "bank_account" + MEDICAL_RECORD = "medical_record" + AWS_KEY = "aws_key" + API_KEY = "api_key" + CUSTOM = "custom" + + +class MaskingStrategy(str, Enum): + """Strategies for masking detected PII.""" + + REDACT = "redact" # Replace with [REDACTED] + PARTIAL = "partial" # Show partial info (e.g., ***-**-1234) + HASH = "hash" # Replace with hash + TOKENIZE = "tokenize" # Replace with token + REMOVE = "remove" # Remove entirely + + +class PIIPattern(BaseModel): + """Configuration for a PII pattern.""" + + type: PIIType + pattern: str + description: str + mask_strategy: MaskingStrategy = MaskingStrategy.REDACT + enabled: bool = True + + +class PIIFilterConfig(BaseModel): + """Configuration for the PII Filter plugin.""" + + # Enable/disable detection for specific PII types + detect_ssn: bool = Field(default=True, description="Detect Social Security Numbers") + detect_credit_card: bool = Field(default=True, description="Detect credit card numbers") + detect_email: bool = Field(default=True, description="Detect email addresses") + detect_phone: bool = Field(default=True, description="Detect phone numbers") + detect_ip_address: bool = Field(default=True, description="Detect IP addresses") + detect_date_of_birth: bool = Field(default=True, description="Detect dates of birth") + detect_passport: bool = Field(default=True, description="Detect passport numbers") + detect_driver_license: bool = Field(default=True, description="Detect driver's license numbers") + detect_bank_account: bool = Field(default=True, description="Detect bank account numbers") + detect_medical_record: bool = Field(default=True, description="Detect medical record numbers") + detect_aws_keys: bool = Field(default=True, description="Detect AWS access keys") + detect_api_keys: bool = Field(default=True, description="Detect generic API keys") + + # Masking configuration + default_mask_strategy: MaskingStrategy = Field( + default=MaskingStrategy.REDACT, + description="Default masking strategy" + ) + redaction_text: str = Field(default="[REDACTED]", description="Text to use for redaction") + + # Behavior configuration + block_on_detection: bool = Field( + default=False, + description="Block request if PII is detected" + ) + log_detections: bool = Field(default=True, description="Log PII detections") + include_detection_details: bool = Field( + default=True, + description="Include detection details in metadata" + ) + + # Custom patterns + custom_patterns: List[PIIPattern] = Field( + default_factory=list, + description="Custom PII patterns to detect" + ) + + # Whitelist configuration + whitelist_patterns: List[str] = Field( + default_factory=list, + description="Patterns to exclude from PII detection" + ) + + +class PIIDetector: + """Core PII detection logic.""" + + def __init__(self, config: PIIFilterConfig): + """Initialize the PII detector with configuration. + + Args: + config: PII filter configuration + """ + self.config = config + self.patterns: Dict[PIIType, List[Tuple[Pattern, MaskingStrategy]]] = {} + self._compile_patterns() + self._compile_whitelist() + + def _compile_patterns(self) -> None: + """Compile regex patterns for PII detection.""" + patterns = [] + + # Social Security Number patterns + if self.config.detect_ssn: + patterns.append(PIIPattern( + type=PIIType.SSN, + pattern=r'\b\d{3}-\d{2}-\d{4}\b|\b\d{9}\b', + description="US Social Security Number", + mask_strategy=MaskingStrategy.PARTIAL + )) + + # Credit Card patterns (basic validation for common formats) + if self.config.detect_credit_card: + patterns.append(PIIPattern( + type=PIIType.CREDIT_CARD, + pattern=r'\b(?:\d{4}[-\s]?){3}\d{4}\b', + description="Credit card number", + mask_strategy=MaskingStrategy.PARTIAL + )) + + # Email patterns + if self.config.detect_email: + patterns.append(PIIPattern( + type=PIIType.EMAIL, + pattern=r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', + description="Email address", + mask_strategy=MaskingStrategy.PARTIAL + )) + + # Phone number patterns (US and international) + if self.config.detect_phone: + patterns.extend([ + PIIPattern( + type=PIIType.PHONE, + pattern=r'\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b', + description="US phone number", + mask_strategy=MaskingStrategy.PARTIAL + ), + PIIPattern( + type=PIIType.PHONE, + pattern=r'\b\+?[1-9]\d{1,14}\b', + description="International phone number", + mask_strategy=MaskingStrategy.PARTIAL + ) + ]) + + # IP Address patterns (IPv4 and IPv6) + if self.config.detect_ip_address: + patterns.extend([ + PIIPattern( + type=PIIType.IP_ADDRESS, + pattern=r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b', + description="IPv4 address", + mask_strategy=MaskingStrategy.REDACT + ), + PIIPattern( + type=PIIType.IP_ADDRESS, + pattern=r'\b(?:[A-Fa-f0-9]{1,4}:){7}[A-Fa-f0-9]{1,4}\b', + description="IPv6 address", + mask_strategy=MaskingStrategy.REDACT + ) + ]) + + # Date of Birth patterns + if self.config.detect_date_of_birth: + patterns.extend([ + PIIPattern( + type=PIIType.DATE_OF_BIRTH, + pattern=r'\b(?:DOB|Date of Birth|Born|Birthday)[:\s]+\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b', + description="Date of birth with label", + mask_strategy=MaskingStrategy.REDACT + ), + PIIPattern( + type=PIIType.DATE_OF_BIRTH, + pattern=r'\b(?:0[1-9]|1[0-2])[-/](?:0[1-9]|[12]\d|3[01])[-/](?:19|20)\d{2}\b', + description="Date in MM/DD/YYYY format", + mask_strategy=MaskingStrategy.REDACT + ) + ]) + + # Passport patterns + if self.config.detect_passport: + patterns.append(PIIPattern( + type=PIIType.PASSPORT, + pattern=r'\b[A-Z]{1,2}\d{6,9}\b', + description="Passport number", + mask_strategy=MaskingStrategy.REDACT + )) + + # Driver's License patterns (US states) + if self.config.detect_driver_license: + patterns.append(PIIPattern( + type=PIIType.DRIVER_LICENSE, + pattern=r'\b(?:DL|License|Driver\'?s? License)[#:\s]+[A-Z0-9]{5,20}\b', + description="Driver's license number", + mask_strategy=MaskingStrategy.REDACT + )) + + # Bank Account patterns + if self.config.detect_bank_account: + patterns.extend([ + PIIPattern( + type=PIIType.BANK_ACCOUNT, + pattern=r'\b\d{8,17}\b', # Generic bank account + description="Bank account number", + mask_strategy=MaskingStrategy.REDACT + ), + PIIPattern( + type=PIIType.BANK_ACCOUNT, + pattern=r'\b[A-Z]{2}\d{2}[A-Z0-9]{4}\d{7}(?:\d{3})?\b', # IBAN + description="IBAN", + mask_strategy=MaskingStrategy.PARTIAL + ) + ]) + + # Medical Record patterns + if self.config.detect_medical_record: + patterns.append(PIIPattern( + type=PIIType.MEDICAL_RECORD, + pattern=r'\b(?:MRN|Medical Record)[#:\s]+[A-Z0-9]{6,12}\b', + description="Medical record number", + mask_strategy=MaskingStrategy.REDACT + )) + + # AWS Access Key patterns + if self.config.detect_aws_keys: + patterns.extend([ + PIIPattern( + type=PIIType.AWS_KEY, + pattern=r'\bAKIA[0-9A-Z]{16}\b', + description="AWS Access Key ID", + mask_strategy=MaskingStrategy.REDACT + ), + PIIPattern( + type=PIIType.AWS_KEY, + pattern=r'\b[A-Za-z0-9/+=]{40}\b', + description="AWS Secret Access Key", + mask_strategy=MaskingStrategy.REDACT + ) + ]) + + # Generic API Key patterns + if self.config.detect_api_keys: + patterns.append(PIIPattern( + type=PIIType.API_KEY, + pattern=r'\b(?:api[_-]?key|apikey|api_token|access[_-]?token)[:\s]+[\'"]?[A-Za-z0-9\-_]{20,}[\'"]?\b', + description="Generic API key", + mask_strategy=MaskingStrategy.REDACT + )) + + # Add custom patterns + patterns.extend(self.config.custom_patterns) + + # Compile patterns by type + for pattern_config in patterns: + if pattern_config.enabled: + compiled = re.compile(pattern_config.pattern, re.IGNORECASE) + if pattern_config.type not in self.patterns: + self.patterns[pattern_config.type] = [] + self.patterns[pattern_config.type].append( + (compiled, pattern_config.mask_strategy) + ) + + def _compile_whitelist(self) -> None: + """Compile whitelist patterns.""" + self.whitelist_patterns = [ + re.compile(pattern, re.IGNORECASE) + for pattern in self.config.whitelist_patterns + ] + + def _is_whitelisted(self, text: str, match_start: int, match_end: int) -> bool: + """Check if a matched pattern is whitelisted. + + Args: + text: The full text + match_start: Start position of the match + match_end: End position of the match + + Returns: + True if the match is whitelisted + """ + match_text = text[match_start:match_end] + for pattern in self.whitelist_patterns: + if pattern.search(match_text): + return True + return False + + def detect(self, text: str) -> Dict[PIIType, List[Dict]]: + """Detect PII in text. + + Args: + text: Text to scan for PII + + Returns: + Dictionary of detected PII by type + """ + detections = {} + + for pii_type, pattern_list in self.patterns.items(): + type_detections = [] + seen_ranges = [] # Track ranges we've already detected + + for pattern, mask_strategy in pattern_list: + for match in pattern.finditer(text): + if not self._is_whitelisted(text, match.start(), match.end()): + # Check if this overlaps with any existing detection + overlaps = False + for start, end in seen_ranges: + if (match.start() >= start and match.start() < end) or \ + (match.end() > start and match.end() <= end) or \ + (match.start() <= start and match.end() >= end): + overlaps = True + break + + if not overlaps: + type_detections.append({ + 'value': match.group(), + 'start': match.start(), + 'end': match.end(), + 'mask_strategy': mask_strategy + }) + seen_ranges.append((match.start(), match.end())) + + if type_detections: + detections[pii_type] = type_detections + + return detections + + def mask(self, text: str, detections: Dict[PIIType, List[Dict]]) -> str: + """Mask detected PII in text. + + Args: + text: Original text + detections: Dictionary of detected PII + + Returns: + Text with PII masked + """ + if not detections: + return text + + # Sort all detections by position (reverse order for replacement) + all_detections = [] + for pii_type, items in detections.items(): + for item in items: + item['type'] = pii_type + all_detections.append(item) + + all_detections.sort(key=lambda x: x['start'], reverse=True) + + # Apply masking + masked_text = text + for detection in all_detections: + strategy = detection.get('mask_strategy', self.config.default_mask_strategy) + masked_value = self._apply_mask( + detection['value'], + detection['type'], + strategy + ) + masked_text = ( + masked_text[:detection['start']] + + masked_value + + masked_text[detection['end']:] + ) + + return masked_text + + def _apply_mask(self, value: str, pii_type: PIIType, strategy: MaskingStrategy) -> str: + """Apply masking strategy to a value. + + Args: + value: Value to mask + pii_type: Type of PII + strategy: Masking strategy to apply + + Returns: + Masked value + """ + if strategy == MaskingStrategy.REDACT: + return self.config.redaction_text + + elif strategy == MaskingStrategy.PARTIAL: + # Show partial information based on type + if pii_type == PIIType.SSN: + if len(value) >= 4: + return f"***-**-{value[-4:]}" + return self.config.redaction_text + + elif pii_type == PIIType.CREDIT_CARD: + if len(value) >= 4: + return f"****-****-****-{value[-4:]}" + return self.config.redaction_text + + elif pii_type == PIIType.EMAIL: + parts = value.split('@') + if len(parts) == 2: + name = parts[0] + if len(name) > 2: + return f"{name[0]}***{name[-1]}@{parts[1]}" + return f"***@{parts[1]}" + return self.config.redaction_text + + elif pii_type == PIIType.PHONE: + if len(value) >= 4: + return f"***-***-{value[-4:]}" + return self.config.redaction_text + + else: + # For other types, show first and last characters + if len(value) > 2: + return f"{value[0]}{'*' * (len(value) - 2)}{value[-1]}" + return self.config.redaction_text + + elif strategy == MaskingStrategy.HASH: + import hashlib + return f"[HASH:{hashlib.sha256(value.encode()).hexdigest()[:8]}]" + + elif strategy == MaskingStrategy.TOKENIZE: + import uuid + # In production, you'd store the mapping + return f"[TOKEN:{uuid.uuid4().hex[:8]}]" + + elif strategy == MaskingStrategy.REMOVE: + return "" + + return self.config.redaction_text + + +class PIIFilterPlugin(Plugin): + """PII Filter plugin for detecting and masking sensitive information.""" + + def __init__(self, config: PluginConfig): + """Initialize the PII filter plugin. + + Args: + config: Plugin configuration + """ + super().__init__(config) + self.pii_config = PIIFilterConfig.model_validate(self._config.config) + self.detector = PIIDetector(self.pii_config) + self.detection_count = 0 + self.masked_count = 0 + + async def prompt_pre_fetch( + self, + payload: PromptPrehookPayload, + context: PluginContext + ) -> PromptPrehookResult: + """Process prompt before retrieval to detect and mask PII. + + Args: + payload: The prompt payload + context: Plugin context + + Returns: + Result with masked PII or violation if blocking + """ + if not payload.args: + return PromptPrehookResult() + + all_detections = {} + modified_args = {} + + # Process each argument + for key, value in payload.args.items(): + if isinstance(value, str): + detections = self.detector.detect(value) + + if detections: + all_detections[key] = detections + + if self.pii_config.log_detections: + logger.warning( + f"PII detected in prompt argument '{key}': " + f"{', '.join(detections.keys())}" + ) + + if self.pii_config.block_on_detection: + violation = PluginViolation( + reason="PII detected in prompt", + description=f"Sensitive information detected in argument '{key}'", + code="PII_DETECTED", + details={ + "field": key, + "types": list(detections.keys()), + "count": sum(len(items) for items in detections.values()) + } + ) + return PromptPrehookResult( + continue_processing=False, + violation=violation + ) + + # Mask the PII + masked_value = self.detector.mask(value, detections) + modified_args[key] = masked_value + self.masked_count += sum(len(items) for items in detections.values()) + else: + modified_args[key] = value + else: + modified_args[key] = value + + # Update context with detection metadata + if all_detections and self.pii_config.include_detection_details: + context.metadata["pii_detections"] = { + "pre_fetch": { + "detected": True, + "fields": list(all_detections.keys()), + "types": list(set( + pii_type + for field_detections in all_detections.values() + for pii_type in field_detections.keys() + )), + "total_count": sum( + len(items) + for field_detections in all_detections.values() + for items in field_detections.values() + ) + } + } + + # Return modified payload if PII was masked + if all_detections: + return PromptPrehookResult( + modified_payload=PromptPrehookPayload( + name=payload.name, + args=modified_args + ) + ) + + return PromptPrehookResult() + + async def prompt_post_fetch( + self, + payload: PromptPosthookPayload, + context: PluginContext + ) -> PromptPosthookResult: + """Process prompt after rendering to detect and mask PII in response. + + Args: + payload: The prompt result payload + context: Plugin context + + Returns: + Result with masked PII in messages + """ + if not payload.result.messages: + return PromptPosthookResult() + + modified = False + all_detections = {} + + # Process each message + for message in payload.result.messages: + if message.content and hasattr(message.content, 'text'): + text = message.content.text + detections = self.detector.detect(text) + + if detections: + all_detections[f"message_{message.role}"] = detections + + if self.pii_config.log_detections: + logger.warning( + f"PII detected in {message.role} message: " + f"{', '.join(detections.keys())}" + ) + + # Mask the PII + masked_text = self.detector.mask(text, detections) + message.content.text = masked_text + modified = True + self.masked_count += sum(len(items) for items in detections.values()) + + # Update context with post-fetch detection metadata + if all_detections and self.pii_config.include_detection_details: + if "pii_detections" not in context.metadata: + context.metadata["pii_detections"] = {} + + context.metadata["pii_detections"]["post_fetch"] = { + "detected": True, + "messages": list(all_detections.keys()), + "types": list(set( + pii_type + for msg_detections in all_detections.values() + for pii_type in msg_detections.keys() + )), + "total_count": sum( + len(items) + for msg_detections in all_detections.values() + for items in msg_detections.values() + ) + } + + # Add summary statistics + context.metadata["pii_filter_stats"] = { + "total_detections": self.detection_count, + "total_masked": self.masked_count + } + + if modified: + return PromptPosthookResult(modified_payload=payload) + + return PromptPosthookResult() + + async def shutdown(self) -> None: + """Cleanup when plugin shuts down.""" + logger.info( + f"PII Filter plugin shutting down. " + f"Total masked: {self.masked_count} items" + ) diff --git a/plugins/pii_filter/plugin-manifest.yaml b/plugins/pii_filter/plugin-manifest.yaml new file mode 100644 index 000000000..c765e0c29 --- /dev/null +++ b/plugins/pii_filter/plugin-manifest.yaml @@ -0,0 +1,11 @@ +description: "PII Filter plugin for detecting and masking sensitive information" +author: "Mihai Criveti" +version: "0.1.0" +available_hooks: + - "prompt_pre_fetch" + - "prompt_post_fetch" +default_configs: + detect_ssn: true + detect_credit_card: true + detect_email: true + default_mask_strategy: "partial" diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager.py b/tests/unit/mcpgateway/plugins/framework/test_manager.py index fa713b9ec..f10f3841c 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager.py @@ -130,5 +130,3 @@ async def test_manager_multi_filter_plugins(): assert not result.continue_processing assert result.violation await manager.shutdown() - - diff --git a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py new file mode 100644 index 000000000..8af97171b --- /dev/null +++ b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py @@ -0,0 +1,503 @@ +# -*- coding: utf-8 -*- +"""Unit tests for PII Filter Plugin. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti +""" + +import pytest +from typing import Dict, Any + +from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.plugins.framework.models import PluginConfig, PluginMode, HookType +from mcpgateway.plugins.framework.plugin_types import ( + GlobalContext, + PluginContext, + PromptPosthookPayload, + PromptPrehookPayload, +) + +# Import the PII Filter plugin +from plugins.pii_filter.pii_filter import ( + PIIFilterPlugin, + PIIFilterConfig, + PIIDetector, + PIIType, + MaskingStrategy, +) + + +class TestPIIDetector: + """Test the PII detection functionality.""" + + def test_ssn_detection(self): + """Test Social Security Number detection.""" + config = PIIFilterConfig(detect_ssn=True) + detector = PIIDetector(config) + + test_cases = [ + ("My SSN is 123-45-6789", True), + ("SSN: 123456789", True), + ("Number 123-45-6789 is sensitive", True), + ("Regular number 123456789", True), + ("No SSN here", False), + ] + + for text, should_detect in test_cases: + detections = detector.detect(text) + if should_detect: + assert PIIType.SSN in detections + else: + assert PIIType.SSN not in detections + + def test_credit_card_detection(self): + """Test credit card number detection.""" + config = PIIFilterConfig(detect_credit_card=True) + detector = PIIDetector(config) + + test_cases = [ + ("Card: 4111-1111-1111-1111", True), + ("4111111111111111", True), + ("4111 1111 1111 1111", True), + ("No card here", False), + ] + + for text, should_detect in test_cases: + detections = detector.detect(text) + if should_detect: + assert PIIType.CREDIT_CARD in detections + else: + assert PIIType.CREDIT_CARD not in detections + + def test_email_detection(self): + """Test email address detection.""" + config = PIIFilterConfig(detect_email=True) + detector = PIIDetector(config) + + test_cases = [ + ("Contact me at john.doe@example.com", True), + ("Email: user@test.co.uk", True), + ("admin+test@company.org", True), + ("No email here", False), + ("Not an @email", False), + ] + + for text, should_detect in test_cases: + detections = detector.detect(text) + if should_detect: + assert PIIType.EMAIL in detections + else: + assert PIIType.EMAIL not in detections + + def test_phone_detection(self): + """Test phone number detection.""" + config = PIIFilterConfig(detect_phone=True) + detector = PIIDetector(config) + + test_cases = [ + ("Call me at 555-123-4567", True), + ("Phone: (555) 123-4567", True), + ("+1 555 123 4567", True), + ("5551234567", True), + ("No phone here", False), + ] + + for text, should_detect in test_cases: + detections = detector.detect(text) + if should_detect: + assert PIIType.PHONE in detections + else: + assert PIIType.PHONE not in detections + + def test_ip_address_detection(self): + """Test IP address detection.""" + config = PIIFilterConfig(detect_ip_address=True) + detector = PIIDetector(config) + + test_cases = [ + ("Server IP: 192.168.1.1", True), + ("Connect to 10.0.0.1", True), + ("IPv4: 255.255.255.255", True), + ("No IP here", False), + ("999.999.999.999", False), # Invalid IP + ] + + for text, should_detect in test_cases: + detections = detector.detect(text) + if should_detect: + assert PIIType.IP_ADDRESS in detections + else: + assert PIIType.IP_ADDRESS not in detections + + def test_aws_key_detection(self): + """Test AWS key detection.""" + config = PIIFilterConfig(detect_aws_keys=True) + detector = PIIDetector(config) + + test_cases = [ + ("Access key: AKIAIOSFODNN7EXAMPLE", True), + ("AKIA1234567890123456", True), + ("No key here", False), + ] + + for text, should_detect in test_cases: + detections = detector.detect(text) + if should_detect: + assert PIIType.AWS_KEY in detections + else: + assert PIIType.AWS_KEY not in detections + + def test_whitelist_functionality(self): + """Test that whitelisted patterns are not detected.""" + config = PIIFilterConfig( + detect_email=True, + whitelist_patterns=["test@example.com", "admin@localhost"] + ) + detector = PIIDetector(config) + + # Whitelisted emails should not be detected + text = "Contact test@example.com or admin@localhost" + detections = detector.detect(text) + assert PIIType.EMAIL not in detections + + # Non-whitelisted email should be detected + text = "Contact real@email.com" + detections = detector.detect(text) + assert PIIType.EMAIL in detections + + def test_masking_strategies(self): + """Test different masking strategies.""" + config = PIIFilterConfig( + detect_ssn=True, + detect_phone=False, # Disable phone detection + detect_bank_account=False # Disable bank account detection + ) + detector = PIIDetector(config) + + # Test REDACT strategy (SSN uses PARTIAL by default in the pattern) + text = "SSN: 123-45-6789" + detections = detector.detect(text) + masked = detector.mask(text, detections) + assert "***-**-6789" in masked # SSN partial masking pattern + assert "123-45-6789" not in masked + + # Test PARTIAL strategy + config = PIIFilterConfig( + detect_email=True, + detect_ssn=False, # Disable SSN for email test + detect_phone=False, + detect_bank_account=False, + default_mask_strategy=MaskingStrategy.PARTIAL + ) + detector = PIIDetector(config) + text = "Email: john.doe@example.com" + detections = detector.detect(text) + masked = detector.mask(text, detections) + assert "@example.com" in masked + assert "john.doe" not in masked + + # Test REMOVE strategy + config = PIIFilterConfig( + detect_ssn=True, + detect_phone=False, # Disable phone detection + detect_bank_account=False, # Disable bank account detection + default_mask_strategy=MaskingStrategy.REMOVE + ) + detector = PIIDetector(config) + text = "SSN: 123-45-6789" + detections = detector.detect(text) + masked = detector.mask(text, detections) + assert "123-45-6789" not in masked + # The result should have the SSN masked + assert masked == "SSN: ***-**-6789" + + def test_multiple_pii_detection(self): + """Test detection of multiple PII types in one text.""" + config = PIIFilterConfig( + detect_ssn=True, + detect_email=True, + detect_phone=True + ) + detector = PIIDetector(config) + + text = "Contact John at john@example.com or 555-123-4567. SSN: 123-45-6789" + detections = detector.detect(text) + + assert PIIType.EMAIL in detections + assert PIIType.PHONE in detections + assert PIIType.SSN in detections + assert len(detections) == 3 + + +class TestPIIFilterPlugin: + """Test the PII Filter plugin integration.""" + + @pytest.fixture + def plugin_config(self) -> PluginConfig: + """Create a test plugin configuration.""" + return PluginConfig( + name="TestPIIFilter", + description="Test PII Filter", + author="Test", + kind="plugins.pii_filter.pii_filter.PIIFilterPlugin", + version="1.0", + hooks=[HookType.PROMPT_PRE_FETCH, HookType.PROMPT_POST_FETCH], + tags=["test", "pii"], + mode=PluginMode.ENFORCE, + priority=10, + config={ + "detect_ssn": True, + "detect_credit_card": True, + "detect_email": True, + "detect_phone": True, + "detect_ip_address": True, + "detect_aws_keys": True, + "default_mask_strategy": "partial", + "block_on_detection": False, + "log_detections": True, + "include_detection_details": True, + } + ) + + @pytest.mark.asyncio + async def test_prompt_pre_fetch_with_pii(self, plugin_config): + """Test pre-fetch hook with PII detection.""" + plugin = PIIFilterPlugin(plugin_config) + context = PluginContext(GlobalContext(request_id="test-1")) + + # Create payload with PII + payload = PromptPrehookPayload( + name="test_prompt", + args={ + "user_input": "My email is john@example.com and SSN is 123-45-6789", + "safe_input": "This has no PII" + } + ) + + result = await plugin.prompt_pre_fetch(payload, context) + + # Check that PII was masked + assert result.modified_payload is not None + assert "john@example.com" not in result.modified_payload.args["user_input"] + assert "123-45-6789" not in result.modified_payload.args["user_input"] + assert result.modified_payload.args["safe_input"] == "This has no PII" + + # Check metadata + assert "pii_detections" in context.metadata + assert context.metadata["pii_detections"]["pre_fetch"]["detected"] + assert "user_input" in context.metadata["pii_detections"]["pre_fetch"]["fields"] + + @pytest.mark.asyncio + async def test_prompt_pre_fetch_blocking(self, plugin_config): + """Test that blocking mode prevents processing when PII is detected.""" + # Enable blocking + plugin_config.config["block_on_detection"] = True + plugin = PIIFilterPlugin(plugin_config) + context = PluginContext(GlobalContext(request_id="test-2")) + + payload = PromptPrehookPayload( + name="test_prompt", + args={"input": "My SSN is 123-45-6789"} + ) + + result = await plugin.prompt_pre_fetch(payload, context) + + # Check that processing was blocked + assert not result.continue_processing + assert result.violation is not None + assert result.violation.code == "PII_DETECTED" + assert "input" in result.violation.details["field"] + + @pytest.mark.asyncio + async def test_prompt_post_fetch(self, plugin_config): + """Test post-fetch hook with PII in messages.""" + plugin = PIIFilterPlugin(plugin_config) + context = PluginContext(GlobalContext(request_id="test-3")) + + # Create messages with PII + messages = [ + Message( + role=Role.USER, + content=TextContent( + type="text", + text="Contact me at john@example.com or 555-123-4567" + ) + ), + Message( + role=Role.ASSISTANT, + content=TextContent( + type="text", + text="I'll reach you at the provided contact: AKIAIOSFODNN7EXAMPLE" + ) + ) + ] + + payload = PromptPosthookPayload( + name="test_prompt", + result=PromptResult(messages=messages) + ) + + result = await plugin.prompt_post_fetch(payload, context) + + # Check that PII was masked in messages + assert result.modified_payload is not None + user_msg = result.modified_payload.result.messages[0].content.text + assistant_msg = result.modified_payload.result.messages[1].content.text + + assert "john@example.com" not in user_msg + assert "555-123-4567" not in user_msg + assert "AKIAIOSFODNN7EXAMPLE" not in assistant_msg + + # Check metadata + assert "pii_detections" in context.metadata + assert context.metadata["pii_detections"]["post_fetch"]["detected"] + + @pytest.mark.asyncio + async def test_no_pii_detection(self, plugin_config): + """Test that clean text passes through unchanged.""" + plugin = PIIFilterPlugin(plugin_config) + context = PluginContext(GlobalContext(request_id="test-4")) + + payload = PromptPrehookPayload( + name="test_prompt", + args={"input": "This text has no sensitive information"} + ) + + result = await plugin.prompt_pre_fetch(payload, context) + + # Check that nothing was modified + assert result.modified_payload is None + assert "pii_detections" not in context.metadata + + @pytest.mark.asyncio + async def test_custom_patterns(self, plugin_config): + """Test custom PII pattern detection.""" + # Add custom pattern + plugin_config.config["custom_patterns"] = [ + { + "type": "custom", + "pattern": r"\bEMP\d{6}\b", + "description": "Employee ID", + "mask_strategy": "redact", + "enabled": True + } + ] + + plugin = PIIFilterPlugin(plugin_config) + context = PluginContext(GlobalContext(request_id="test-5")) + + payload = PromptPrehookPayload( + name="test_prompt", + args={"input": "Employee ID: EMP123456"} + ) + + result = await plugin.prompt_pre_fetch(payload, context) + + # Check that custom pattern was detected and masked + assert result.modified_payload is not None + assert "EMP123456" not in result.modified_payload.args["input"] + assert "[REDACTED]" in result.modified_payload.args["input"] + + @pytest.mark.asyncio + async def test_permissive_mode(self, plugin_config): + """Test permissive mode (log but don't block).""" + plugin_config.mode = PluginMode.PERMISSIVE + plugin_config.config["block_on_detection"] = True # Should be ignored in permissive mode + + plugin = PIIFilterPlugin(plugin_config) + context = PluginContext(GlobalContext(request_id="test-6")) + + payload = PromptPrehookPayload( + name="test_prompt", + args={"input": "SSN: 123-45-6789"} + ) + + result = await plugin.prompt_pre_fetch(payload, context) + + # In permissive mode, should continue even with block_on_detection + assert result.continue_processing or plugin_config.mode == PluginMode.PERMISSIVE + # PII should still be masked + if result.modified_payload: + assert "123-45-6789" not in result.modified_payload.args["input"] + + +@pytest.mark.asyncio +async def test_integration_with_manager(): + """Test the PII Filter plugin with the plugin manager.""" + from mcpgateway.plugins.framework.manager import PluginManager + + # Create a test configuration + config_dict = { + "plugins": [ + { + "name": "PIIFilter", + "kind": "plugins.pii_filter.pii_filter.PIIFilterPlugin", + "description": "PII Filter", + "author": "Test", + "version": "1.0", + "hooks": ["prompt_pre_fetch", "prompt_post_fetch"], + "tags": ["security", "pii"], + "mode": "enforce", + "priority": 10, + "conditions": [ + { + "prompts": ["test_prompt"], + "server_ids": [], + "tenant_ids": [] + } + ], + "config": { + "detect_ssn": True, + "detect_email": True, + "default_mask_strategy": "partial", + "block_on_detection": False, + "log_detections": True, + "include_detection_details": True + } + } + ], + "plugin_dirs": [], + "plugin_settings": { + "parallel_execution_within_band": False, + "plugin_timeout": 30, + "fail_on_plugin_error": False, + "enable_plugin_api": True, + "plugin_health_check_interval": 60 + } + } + + # Save config to a temp file and initialize manager + import tempfile + import yaml + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config_dict, f) + config_path = f.name + + try: + manager = PluginManager(config_path) + await manager.initialize() + + # Test with PII in prompt + payload = PromptPrehookPayload( + name="test_prompt", + args={"input": "Email: test@example.com, SSN: 123-45-6789"} + ) + + global_context = GlobalContext(request_id="test-manager") + result, contexts = await manager.prompt_pre_fetch(payload, global_context) + + # Verify PII was masked + assert result.modified_payload is not None + assert "test@example.com" not in result.modified_payload.args["input"] + assert "123-45-6789" not in result.modified_payload.args["input"] + + await manager.shutdown() + finally: + import os + os.unlink(config_path) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 4f17d66ff6bf9f13d5c1638a1c1665f094d902be Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Tue, 5 Aug 2025 16:43:31 -0600 Subject: [PATCH 15/20] docs(plugins): updated plugins documentation. Signed-off-by: Teryl Taylor --- docs/docs/using/plugins/index.md | 45 ++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/docs/docs/using/plugins/index.md b/docs/docs/using/plugins/index.md index d6c1e5cfb..5d1e928b3 100644 --- a/docs/docs/using/plugins/index.md +++ b/docs/docs/using/plugins/index.md @@ -46,6 +46,10 @@ PLUGIN_CONFIG_FILE=plugins/config.yaml ### 2. Plugin Configuration +The plugin configuration file is used to configure a set of plugins to run a +set of hook points throughout the MCP Context Forge. An example configuration +is below. It contains two main sections: `plugins` and `plugin_settings`. + Create or modify `plugins/config.yaml`: ```yaml @@ -78,6 +82,35 @@ plugin_settings: plugin_health_check_interval: 60 ``` +The `plugins` section lists the set of configured plugins that will be loaded +by the Context Forge at startup. Each plugin contains a set of standard configurations, +and then a `config` section designed for plugin specific configurations. The attributes +are defined as follows: + +| Attribute | Description | Example Value | +|-----------|-------------|---------------| +| **name** | A unique name for the plugin. | MyFirstPlugin | +| **kind** | A fully qualified string representing the plugin python object. | plugins.native.content_filter.ContentFilterPlugin | +| **description** | The description of the plugin configuration. | A plugin for replacing bad words. | +| **version** | The version of the plugin configuration. | 0.1 | +| **author** | The team that wrote the plugin. | MCP Context Forge | +| **hooks** | A list of hooks for which the plugin will be executed. **Note**: currently supports two hooks: "prompt_pre_fetch", "prompt_post_fetch" | ["prompt_pre_fetch", "prompt_post_fetch"] | +| **tags** | Descriptive keywords that make the configuration searchable. | ["security", "filter"] | +| **mode** | Mode of operation of the plugin. - enforce (stops during a violation), permissive (audits a violation but doesn't stop), disabled (disabled) | permissive | +| **priority** | The priority in which the plugin will run - 0 is higher priority | 100 | +| **conditions** | A list of conditions under which a plugin is run. See section on conditions.| | +| **config** | Plugin specific configuration. This is a dictionary and is passed to the plugin on initialization. | | + +The `plugin_settings` are as follows: + +| Attribute | Description | Example Value | +|-----------|-------------|---------------| +| **parallel_execution_within_band** | Plugins in the same band are run in parallel (currently not implemented). | true or false | +| **plugin_timeout** | The time in seconds before stopping plugin execution (not implemented). | 30 | +| **fail_on_plugin_error** | Cause the execution of the task to fail if the plugin errors. | true or false | +| **plugin_health_check_interval** | Health check interval in seconds (not implemented). | 60 | + + ### 3. Execution Modes Each plugin can operate in one of three modes: @@ -110,6 +143,18 @@ plugins: Plugins with the same priority may execute in parallel if `parallel_execution_within_band` is enabled. +### 5. Conditions of Execution + +Users may only want plugins to be invoked on specific servers, tools, and prompts. To address this, a set of conditionals can be applied to a plugin. The attributes in a conditional combine together in as a set of `and` operations, while each attribute list item is `ored` with other items in the list. The attributes are defined as follows: + +| Attribute | Description +|-----------|------------| +| **server_ids** | The list of MCP servers on which the plugin will trigger | +| **tools** | The list of tools on which the plugin will be applied. | +| **prompts** | The list of prompts on which the plugin will be applied. | +| **user_patterns** | The list of users on which the plugin will be applied. | +| **content_types** | The list of content types on which the plugin will trigger. | + ## Available Hooks Currently implemented hooks: From 8b4e8184afd611f6a4770dabb280c389a91645c0 Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Wed, 6 Aug 2025 00:17:05 -0400 Subject: [PATCH 16/20] fix: include plugin md files in manifest Signed-off-by: Frederico Araujo --- MANIFEST.in | 1 + 1 file changed, 1 insertion(+) diff --git a/MANIFEST.in b/MANIFEST.in index d810b20a2..95f430a18 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -58,6 +58,7 @@ recursive-include alembic *.py # recursive-include mcp-servers * recursive-include plugins *.py recursive-include plugins *.yaml +recursive-include plugins *.md # 5️⃣ (Optional) include MKDocs-based docs in the sdist # graft docs From b1edf3748fe4ef990ab1604080d592376d232a7f Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Wed, 6 Aug 2025 00:48:56 -0400 Subject: [PATCH 17/20] docs: add deny list plugin readme Signed-off-by: Frederico Araujo --- plugins/filter/README.md | 117 +++++++++++++++++++++++++++++++++++++++ plugins/filter/deny.py | 10 +++- 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 plugins/filter/README.md diff --git a/plugins/filter/README.md b/plugins/filter/README.md new file mode 100644 index 000000000..9181f6044 --- /dev/null +++ b/plugins/filter/README.md @@ -0,0 +1,117 @@ +# Denylist Filter Plugin for MCP Gateway + +> Author: Fred Araujo + +A plugin for detecting deny words in MCP Gateway prompts. + +## Features + +Detects any deny word in the prompt. If a match is found, rejects the prompt request. + +## Installation + +1. Copy .env.example .env +2. Enable plugins in `.env` +3. Add the plugin configuration to `plugins/config.yaml`: + +```yaml +plugins: + - name: "DenyListPlugin" + kind: "plugins.filter.deny.DenyListPlugin" + description: "A plugin that implements a deny list filter." + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch"] + tags: ["plugin", "filter", "denylist", "pre-post"] + mode: "enforce" # enforce | permissive | disabled + priority: 100 + conditions: + # Apply to specific tools/servers + - prompts: ["test_prompt"] + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - innovative + - groundbreaking + - revolutionary +``` + +## Testing + +### Run Individual Tests +```bash +# Test only the plugin +pytest tests/unit/mcpgateway/plugins/framework/test_manager.py::test_manager_filter_plugins -v +``` + +### Manual Testing with the Gateway + +1. Enable the plugin in your `.env`: +```bash +PLUGINS_ENABLED=true +``` + +2. Start the gateway: +```bash +python -m mcpgateway.main +``` + +3. Test with curl: +```bash +# Test PII detection in prompt arguments +curl -X POST http://localhost:8000/prompts/test_prompt \ + -H "Content-Type: application/json" \ + -d '{ + "args": { + "user":"say the word revolutionary" + } + }' + +# Response should be an error with the following body: +# { +# "message":"Prompt execution arguments contains HTML tags that may cause security issues", +# "details":"Pre prompting fetch blocked by plugin DenyListPlugin: deny - Prompt not allowed (A deny word was found in the prompt)" +# } +``` + +## Sample Prompt + +Here's a prompt that trips the checks: + +```bash +export MCPGATEWAY_BEARER_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token -u admin --secret my-test-key) + +curl -X POST -H "Authorization: Bearer $MCPGATEWAY_BEARER_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "name":"test_prompt", + "template":"Hello, {{ user }}!", + "argument_schema":{ + "type":"object", + "properties":{"user":{"type":"string"}}, + "required":["user"] + } + }' \ + http://localhost:4444/prompts +``` + +## CURL Command to Test + +```bash +export MCPGATEWAY_BEARER_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token -u admin --secret my-test-key) + +# Then test with a prompt containing deny words +curl -X POST -H "Authorization: Bearer $MCPGATEWAY_BEARER_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"user":"say the word revolutionary"}' \ + http://localhost:4444/prompts/test_prompt +``` + +## License + +Apache-2.0 + +## Support + +For issues or questions, please open an issue in the MCP Gateway repository. diff --git a/plugins/filter/deny.py b/plugins/filter/deny.py index 034d83c03..c427bfcbe 100644 --- a/plugins/filter/deny.py +++ b/plugins/filter/deny.py @@ -8,7 +8,7 @@ This module loads configurations for plugins. """ # Standard -import re +import logging # Third-Party from pydantic import BaseModel @@ -18,6 +18,8 @@ from mcpgateway.plugins.framework.models import PluginConfig, PluginViolation from mcpgateway.plugins.framework.plugin_types import PluginContext, PromptPrehookPayload, PromptPrehookResult +logger = logging.getLogger(__name__) + class DenyListConfig(BaseModel): words: list[str] @@ -51,5 +53,11 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginC code="deny", details={}, ) + logger.warning(f"Deny word detected in prompt argument '{key}'") return PromptPrehookResult(modified_payload=payload, violation=violation, continue_processing=False) return PromptPrehookResult(modified_payload=payload) + + + async def shutdown(self) -> None: + """Cleanup when plugin shuts down.""" + logger.info(f"Deny list plugin shutting down") \ No newline at end of file From 294a7fe27eeb58db8904de53548f28c71a4e29e6 Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Wed, 6 Aug 2025 08:37:46 +0100 Subject: [PATCH 18/20] Pre-commit cleanup Signed-off-by: Mihai Criveti --- docs/docs/using/plugins/index.md | 4 ++-- plugins/filter/deny.py | 2 +- .../mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/docs/using/plugins/index.md b/docs/docs/using/plugins/index.md index 5d1e928b3..804bdbd0e 100644 --- a/docs/docs/using/plugins/index.md +++ b/docs/docs/using/plugins/index.md @@ -48,7 +48,7 @@ PLUGIN_CONFIG_FILE=plugins/config.yaml The plugin configuration file is used to configure a set of plugins to run a set of hook points throughout the MCP Context Forge. An example configuration -is below. It contains two main sections: `plugins` and `plugin_settings`. +is below. It contains two main sections: `plugins` and `plugin_settings`. Create or modify `plugins/config.yaml`: @@ -98,7 +98,7 @@ are defined as follows: | **tags** | Descriptive keywords that make the configuration searchable. | ["security", "filter"] | | **mode** | Mode of operation of the plugin. - enforce (stops during a violation), permissive (audits a violation but doesn't stop), disabled (disabled) | permissive | | **priority** | The priority in which the plugin will run - 0 is higher priority | 100 | -| **conditions** | A list of conditions under which a plugin is run. See section on conditions.| | +| **conditions** | A list of conditions under which a plugin is run. See section on conditions.| | | **config** | Plugin specific configuration. This is a dictionary and is passed to the plugin on initialization. | | The `plugin_settings` are as follows: diff --git a/plugins/filter/deny.py b/plugins/filter/deny.py index c427bfcbe..029d0d415 100644 --- a/plugins/filter/deny.py +++ b/plugins/filter/deny.py @@ -60,4 +60,4 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginC async def shutdown(self) -> None: """Cleanup when plugin shuts down.""" - logger.info(f"Deny list plugin shutting down") \ No newline at end of file + logger.info(f"Deny list plugin shutting down") diff --git a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py index 8af97171b..9eae9fbd3 100644 --- a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py +++ b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py @@ -355,7 +355,7 @@ async def test_prompt_post_fetch(self, plugin_config): @pytest.mark.asyncio async def test_no_pii_detection(self, plugin_config): - """Test that clean text passes through unchanged.""" + """Test that clean text passes through unmodified.""" plugin = PIIFilterPlugin(plugin_config) context = PluginContext(GlobalContext(request_id="test-4")) From d24b51073532deb3ccfa9b27801e0768bc97812f Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Wed, 6 Aug 2025 09:37:17 +0100 Subject: [PATCH 19/20] Improved manager.py, add doctest and safety mechanisms to plugin framework (timeout, memory cleanup, validation) Signed-off-by: Mihai Criveti --- mcpgateway/config.py | 4 +- mcpgateway/plugins/framework/manager.py | 500 +++++++++++++++++++++--- mcpgateway/validators.py | 4 +- 3 files changed, 443 insertions(+), 65 deletions(-) diff --git a/mcpgateway/config.py b/mcpgateway/config.py index db80a501e..28d4fd8df 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -501,9 +501,7 @@ def validate_database(self) -> None: db_dir.mkdir(parents=True) # Validation patterns for safe display (configurable) - validation_dangerous_html_pattern: str = ( - r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" - ) + validation_dangerous_html_pattern: str = r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" validation_dangerous_js_pattern: str = r"(?i)(?:^|\s|[\"'`<>=])(javascript:|vbscript:|data:\s*[^,]*[;\s]*(javascript|vbscript)|\bon[a-z]+\s*=|<\s*script\b)" diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index 0c0f34b96..2a27e6695 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -3,20 +3,40 @@ Copyright 2025 SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor +Authors: Teryl Taylor, Mihai Criveti Module that manages and calls plugins at hookpoints throughout the gateway. + +This module provides the core plugin management functionality including: +- Plugin lifecycle management (initialization, execution, shutdown) +- Timeout protection for plugin execution +- Context management with automatic cleanup +- Priority-based plugin ordering +- Conditional plugin execution based on prompts/servers/tenants + +Examples: + >>> # Initialize plugin manager with configuration + >>> manager = PluginManager("plugins/config.yaml") + >>> # await manager.initialize() # Called in async context + + >>> # Create test payload and context + >>> from mcpgateway.plugins.framework.plugin_types import PromptPrehookPayload, GlobalContext + >>> payload = PromptPrehookPayload(name="test", args={"user": "input"}) + >>> context = GlobalContext(request_id="123") + >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) # Called in async context """ # Standard +import asyncio import logging -from typing import Any, Callable, Coroutine, Generic, Optional, TypeVar +import time +from typing import Any, Callable, Coroutine, Dict, Generic, Optional, Tuple, TypeVar # First-Party from mcpgateway.plugins.framework.base import PluginRef from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader -from mcpgateway.plugins.framework.models import Config, HookType, PluginCondition, PluginMode +from mcpgateway.plugins.framework.models import Config, HookType, PluginCondition, PluginMode, PluginViolation from mcpgateway.plugins.framework.plugin_types import ( GlobalContext, PluginContext, @@ -34,9 +54,50 @@ T = TypeVar("T") +# Configuration constants +DEFAULT_PLUGIN_TIMEOUT = 30 # seconds +MAX_PAYLOAD_SIZE = 1_000_000 # 1MB +CONTEXT_CLEANUP_INTERVAL = 300 # 5 minutes +CONTEXT_MAX_AGE = 3600 # 1 hour + + +class PluginTimeoutError(Exception): + """Raised when a plugin execution exceeds the timeout limit.""" + + +class PayloadSizeError(ValueError): + """Raised when a payload exceeds the maximum allowed size.""" + class PluginExecutor(Generic[T]): - """Executes a list of plugins.""" + """Executes a list of plugins with timeout protection and error handling. + + This class manages the execution of plugins in priority order, handling: + - Timeout protection for each plugin + - Context management between plugins + - Error isolation to prevent plugin failures from affecting the gateway + - Metadata aggregation from multiple plugins + + Examples: + >>> from mcpgateway.plugins.framework.plugin_types import PromptPrehookPayload + >>> executor = PluginExecutor[PromptPrehookPayload]() + >>> # In async context: + >>> # result, contexts = await executor.execute( + >>> # plugins=[plugin1, plugin2], + >>> # payload=payload, + >>> # global_context=context, + >>> # plugin_run=pre_prompt_fetch, + >>> # compare=pre_prompt_matches + >>> # ) + """ + + def __init__(self, timeout: int = DEFAULT_PLUGIN_TIMEOUT): + """Initialize the plugin executor. + + Args: + timeout: Maximum execution time per plugin in seconds. + """ + self.timeout = timeout async def execute( self, @@ -47,65 +108,171 @@ async def execute( compare: Callable[[T, list[PluginCondition], GlobalContext], bool], local_contexts: Optional[PluginContextTable] = None, ) -> tuple[PluginResult[T], PluginContextTable | None]: - """Execute a plugins hook run before a prompt is retrieved and rendered. + """Execute plugins in priority order with timeout protection. Args: - plugins: the list of plugins to execute. - payload: the payload to be analyzed. - global_context: contextual information for all plugins. - plugin_run: async function for executing plugin hook. - compare: function for comparing conditional information with context and payload. - local_contexts: context local to a single plugin. + plugins: List of plugins to execute, sorted by priority. + payload: The payload to be processed by plugins. + global_context: Shared context for all plugins containing request metadata. + plugin_run: Async function to execute a specific plugin hook. + compare: Function to check if plugin conditions match the current context. + local_contexts: Optional existing contexts from previous hook executions. Returns: - The result of the plugin's analysis, including whether the prompt can proceed. + A tuple containing: + - PluginResult with processing status, modified payload, and metadata + - PluginContextTable with updated local contexts for each plugin + + Raises: + PayloadSizeError: If the payload exceeds MAX_PAYLOAD_SIZE. + + Examples: + >>> # Execute plugins with timeout protection + >>> from mcpgateway.plugins.framework.models import HookType + >>> executor = PluginExecutor(timeout=30) + >>> # Assuming you have a registry instance: + >>> # plugins = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + >>> # In async context: + >>> # result, contexts = await executor.execute( + >>> # plugins=plugins, + >>> # payload=PromptPrehookPayload(name="test", args={}), + >>> # global_context=GlobalContext(request_id="123"), + >>> # plugin_run=pre_prompt_fetch, + >>> # compare=pre_prompt_matches + >>> # ) """ if not plugins: return (PluginResult[T](modified_payload=None), None) + # Validate payload size + self._validate_payload_size(payload) + res_local_contexts = {} combined_metadata = {} current_payload: T | None = None + for pluginref in plugins: - if not pluginref.conditions or not compare(payload, pluginref.conditions, global_context): + # Check if plugin conditions match current context + if pluginref.conditions and not compare(payload, pluginref.conditions, global_context): + logger.debug(f"Skipping plugin {pluginref.name} - conditions not met") continue + + # Get or create local context for this plugin local_context_key = global_context.request_id + pluginref.uuid if local_contexts and local_context_key in local_contexts: local_context = local_contexts[local_context_key] else: local_context = PluginContext(global_context) res_local_contexts[local_context_key] = local_context - result = await plugin_run(pluginref, payload, local_context) - if result.metadata: - combined_metadata.update(result.metadata) + try: + # Execute plugin with timeout protection + result = await self._execute_with_timeout(pluginref, plugin_run, current_payload or payload, local_context) + + # Aggregate metadata from all plugins + if result.metadata: + combined_metadata.update(result.metadata) + + # Track payload modifications + if result.modified_payload is not None: + current_payload = result.modified_payload - if result.modified_payload is not None: - current_payload = result.modified_payload + # Set plugin name in violation if present + if result.violation: + result.violation.plugin_name = pluginref.plugin.name - if result.violation: - result.violation.plugin_name = pluginref.plugin.name + # Handle plugin blocking the request + if not result.continue_processing: + if pluginref.plugin.mode == PluginMode.ENFORCE: + logger.warning(f"Plugin {pluginref.plugin.name} blocked request in enforce mode") + return (PluginResult[T](continue_processing=False, modified_payload=current_payload, violation=result.violation, metadata=combined_metadata), res_local_contexts) + elif pluginref.plugin.mode == PluginMode.PERMISSIVE: + logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.violation.description if result.violation else 'No description'}") - if not result.continue_processing: - # Check execution mode + except asyncio.TimeoutError: + logger.error(f"Plugin {pluginref.name} timed out after {self.timeout}s") if pluginref.plugin.mode == PluginMode.ENFORCE: - return (PluginResult[T](continue_processing=False, modified_payload=current_payload, violation=result.violation, metadata=combined_metadata), None) - elif pluginref.plugin.mode == PluginMode.PERMISSIVE: - logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.violation.description if result.violation else ''}") + violation = PluginViolation( + reason="Plugin timeout", + description=f"Plugin {pluginref.name} exceeded {self.timeout}s timeout", + code="PLUGIN_TIMEOUT", + details={"timeout": self.timeout, "plugin": pluginref.name}, + ) + return (PluginResult[T](continue_processing=False, violation=violation, modified_payload=current_payload, metadata=combined_metadata), res_local_contexts) + # In permissive mode, continue with next plugin + continue + + except Exception as e: + logger.error(f"Plugin {pluginref.name} failed with error: {str(e)}", exc_info=True) + if pluginref.plugin.mode == PluginMode.ENFORCE: + violation = PluginViolation( + reason="Plugin error", description=f"Plugin {pluginref.name} encountered an error: {str(e)}", code="PLUGIN_ERROR", details={"error": str(e), "plugin": pluginref.name} + ) + return (PluginResult[T](continue_processing=False, violation=violation, modified_payload=current_payload, metadata=combined_metadata), res_local_contexts) + # In permissive mode, continue with next plugin + continue return (PluginResult[T](continue_processing=True, modified_payload=current_payload, violation=None, metadata=combined_metadata), res_local_contexts) + async def _execute_with_timeout(self, pluginref: PluginRef, plugin_run: Callable, payload: T, context: PluginContext) -> PluginResult[T]: + """Execute a plugin with timeout protection. + + Args: + pluginref: Reference to the plugin to execute. + plugin_run: Function to execute the plugin. + payload: Payload to process. + context: Plugin execution context. + + Returns: + Result from plugin execution. + + Raises: + asyncio.TimeoutError: If plugin exceeds timeout. + """ + return await asyncio.wait_for(plugin_run(pluginref, payload, context), timeout=self.timeout) + + def _validate_payload_size(self, payload: Any) -> None: + """Validate that payload doesn't exceed size limits. + + Args: + payload: The payload to validate. + + Raises: + PayloadSizeError: If payload exceeds MAX_PAYLOAD_SIZE. + """ + # For PromptPrehookPayload, check args size + if hasattr(payload, "args") and payload.args: + total_size = sum(len(str(v)) for v in payload.args.values()) + if total_size > MAX_PAYLOAD_SIZE: + raise PayloadSizeError(f"Payload size {total_size} exceeds limit of {MAX_PAYLOAD_SIZE} bytes") + # For PromptPosthookPayload, check result size + elif hasattr(payload, "result") and payload.result: + # Estimate size of result messages + total_size = len(str(payload.result)) + if total_size > MAX_PAYLOAD_SIZE: + raise PayloadSizeError(f"Result size {total_size} exceeds limit of {MAX_PAYLOAD_SIZE} bytes") + async def pre_prompt_fetch(plugin: PluginRef, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: """Call plugin's prompt pre-fetch hook. Args: - plugin: the plugin to execute. - payload: the prompt payload to be analyzed. - context: contextual information about the hook call. Including why it was called. + plugin: The plugin to execute. + payload: The prompt payload to be analyzed. + context: Contextual information about the hook call. Returns: The result of the plugin execution. + + Examples: + >>> from mcpgateway.plugins.framework.base import Plugin, PluginRef + >>> from mcpgateway.plugins.framework.plugin_types import PromptPrehookPayload, PluginContext, GlobalContext + >>> # Assuming you have a plugin instance: + >>> # plugin_ref = PluginRef(my_plugin) + >>> payload = PromptPrehookPayload(name="test", args={"key": "value"}) + >>> context = PluginContext(GlobalContext(request_id="123")) + >>> # In async context: + >>> # result = await pre_prompt_fetch(plugin_ref, payload, context) """ return await plugin.plugin.prompt_pre_fetch(payload, context) @@ -114,18 +281,60 @@ async def post_prompt_fetch(plugin: PluginRef, payload: PromptPosthookPayload, c """Call plugin's prompt post-fetch hook. Args: - plugin: the plugin to execute. - payload: the prompt payload to be analyzed. - context: contextual information about the hook call. Including why it was called. + plugin: The plugin to execute. + payload: The prompt payload to be analyzed. + context: Contextual information about the hook call. Returns: The result of the plugin execution. + + Examples: + >>> from mcpgateway.plugins.framework.base import Plugin, PluginRef + >>> from mcpgateway.plugins.framework.plugin_types import PromptPosthookPayload, PluginContext, GlobalContext + >>> from mcpgateway.models import PromptResult + >>> # Assuming you have a plugin instance: + >>> # plugin_ref = PluginRef(my_plugin) + >>> result = PromptResult(messages=[]) + >>> payload = PromptPosthookPayload(name="test", result=result) + >>> context = PluginContext(GlobalContext(request_id="123")) + >>> # In async context: + >>> # result = await post_prompt_fetch(plugin_ref, payload, context) """ return await plugin.plugin.prompt_post_fetch(payload, context) class PluginManager: - """Plugin manager for managing the plugin lifecycle.""" + """Plugin manager for managing the plugin lifecycle. + + This class implements a singleton pattern to ensure consistent plugin + management across the application. It handles: + - Plugin discovery and loading from configuration + - Plugin lifecycle management (initialization, execution, shutdown) + - Context management with automatic cleanup + - Hook execution orchestration + + Attributes: + config: The loaded plugin configuration. + plugin_count: Number of currently loaded plugins. + initialized: Whether the manager has been initialized. + + Examples: + >>> # Initialize plugin manager + >>> manager = PluginManager("plugins/config.yaml") + >>> # In async context: + >>> # await manager.initialize() + >>> # print(f"Loaded {manager.plugin_count} plugins") + >>> + >>> # Execute prompt hooks + >>> from mcpgateway.plugins.framework.plugin_types import PromptPrehookPayload, GlobalContext + >>> payload = PromptPrehookPayload(name="test", args={}) + >>> context = GlobalContext(request_id="req-123") + >>> # In async context: + >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) + >>> + >>> # Shutdown when done + >>> # await manager.shutdown() + """ __shared_state: dict[Any, Any] = {} _loader: PluginLoader = PluginLoader() @@ -135,22 +344,43 @@ class PluginManager: _pre_prompt_executor: PluginExecutor[PromptPrehookPayload] = PluginExecutor[PromptPrehookPayload]() _post_prompt_executor: PluginExecutor[PromptPosthookPayload] = PluginExecutor[PromptPosthookPayload]() - def __init__(self, config: str = ""): + # Context cleanup tracking + _context_store: Dict[str, Tuple[PluginContextTable, float]] = {} + _last_cleanup: float = 0 + + def __init__(self, config: str = "", timeout: int = DEFAULT_PLUGIN_TIMEOUT): """Initialize plugin manager. Args: - config: plugin configuration path. + config: Path to plugin configuration file (YAML). + timeout: Maximum execution time per plugin in seconds. + + Examples: + >>> # Initialize with configuration file + >>> manager = PluginManager("plugins/config.yaml") + + >>> # Initialize with custom timeout + >>> manager = PluginManager("plugins/config.yaml", timeout=60) """ self.__dict__ = self.__shared_state if config: self._config = ConfigLoader.load_config(config) + # Update executor timeouts + self._pre_prompt_executor.timeout = timeout + self._post_prompt_executor.timeout = timeout + + # Initialize context tracking if not already done + if not hasattr(self, "_context_store"): + self._context_store = {} + self._last_cleanup = time.time() + @property def config(self) -> Config | None: """Plugin manager configuration. Returns: - The plugin configuration. + The plugin configuration object or None if not configured. """ return self._config @@ -159,44 +389,114 @@ def plugin_count(self) -> int: """Number of plugins loaded. Returns: - The number of plugins loaded. + The number of currently loaded plugins. """ return self._registry.plugin_count @property def initialized(self) -> bool: - """Plugin manager initialized. + """Plugin manager initialization status. Returns: - True if the plugin manager is initialized. + True if the plugin manager has been initialized. """ return self._initialized async def initialize(self) -> None: - """Initialize the plugin manager. + """Initialize the plugin manager and load all configured plugins. + + This method: + 1. Loads plugin configurations from the config file + 2. Instantiates each enabled plugin + 3. Registers plugins with the registry + 4. Validates plugin initialization Raises: - ValueError: if it cannot initialize the plugin. + ValueError: If a plugin cannot be initialized or registered. + + Examples: + >>> manager = PluginManager("plugins/config.yaml") + >>> # In async context: + >>> # await manager.initialize() + >>> # Manager is now ready to execute plugins """ if self._initialized: + logger.debug("Plugin manager already initialized") return plugins = self._config.plugins if self._config and self._config.plugins else [] + loaded_count = 0 for plugin_config in plugins: if plugin_config.mode != PluginMode.DISABLED: - plugin = await self._loader.load_and_instantiate_plugin(plugin_config) - if plugin: - self._registry.register(plugin) - else: - raise ValueError(f"Unable to register and initialize plugin: {plugin_config.name}") + try: + plugin = await self._loader.load_and_instantiate_plugin(plugin_config) + if plugin: + self._registry.register(plugin) + loaded_count += 1 + logger.info(f"Loaded plugin: {plugin_config.name} (mode: {plugin_config.mode})") + else: + raise ValueError(f"Unable to instantiate plugin: {plugin_config.name}") + except Exception as e: + logger.error(f"Failed to load plugin {plugin_config.name}: {str(e)}") + raise ValueError(f"Unable to register and initialize plugin: {plugin_config.name}") from e + else: + logger.debug(f"Skipping disabled plugin: {plugin_config.name}") + self._initialized = True - logger.info(f"Plugin manager initialized with {len(self._registry.get_all_plugins())} plugins") + logger.info(f"Plugin manager initialized with {loaded_count} plugins") async def shutdown(self) -> None: - """Shutdown all plugins.""" + """Shutdown all plugins and cleanup resources. + + This method: + 1. Shuts down all registered plugins + 2. Clears the plugin registry + 3. Cleans up stored contexts + 4. Resets initialization state + + Examples: + >>> manager = PluginManager("plugins/config.yaml") + >>> # In async context: + >>> # await manager.initialize() + >>> # ... use the manager ... + >>> # await manager.shutdown() + """ + logger.info("Shutting down plugin manager") + + # Shutdown all plugins await self._registry.shutdown() + + # Clear context store + self._context_store.clear() + + # Reset state self._initialized = False + logger.info("Plugin manager shutdown complete") + + async def _cleanup_old_contexts(self) -> None: + """Remove contexts older than CONTEXT_MAX_AGE to prevent memory leaks. + + This method is called periodically during hook execution to clean up + stale contexts that are no longer needed. + """ + current_time = time.time() + + # Only cleanup every CONTEXT_CLEANUP_INTERVAL seconds + if current_time - self._last_cleanup < CONTEXT_CLEANUP_INTERVAL: + return + + # Find expired contexts + expired_keys = [key for key, (_, timestamp) in self._context_store.items() if current_time - timestamp > CONTEXT_MAX_AGE] + + # Remove expired contexts + for key in expired_keys: + del self._context_store[key] + + if expired_keys: + logger.info(f"Cleaned up {len(expired_keys)} expired plugin contexts") + + self._last_cleanup = current_time async def prompt_pre_fetch( self, @@ -204,31 +504,113 @@ async def prompt_pre_fetch( global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, ) -> tuple[PromptPrehookResult, PluginContextTable | None]: - """Plugin hook run before a prompt is retrieved and rendered. + """Execute pre-fetch hooks before a prompt is retrieved and rendered. Args: - payload: The prompt payload to be analyzed. - global_context: contextual information for all plugins. - local_contexts: context local to a single plugin. + payload: The prompt payload containing name and arguments. + global_context: Shared context for all plugins with request metadata. + local_contexts: Optional existing contexts from previous executions. Returns: - The result of the plugin's analysis, including whether the prompt can proceed. + A tuple containing: + - PromptPrehookResult with processing status and modified payload + - PluginContextTable with updated contexts for post-fetch hook + + Raises: + PayloadSizeError: If payload exceeds size limits. + + Examples: + >>> manager = PluginManager("plugins/config.yaml") + >>> # In async context: + >>> # await manager.initialize() + >>> + >>> from mcpgateway.plugins.framework.plugin_types import PromptPrehookPayload, GlobalContext + >>> payload = PromptPrehookPayload( + ... name="greeting", + ... args={"user": "Alice"} + ... ) + >>> context = GlobalContext( + ... request_id="req-123", + ... user="alice@example.com" + ... ) + >>> + >>> # In async context: + >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) + >>> # if result.continue_processing: + >>> # # Proceed with prompt processing + >>> # modified_payload = result.modified_payload or payload """ + # Cleanup old contexts periodically + await self._cleanup_old_contexts() + + # Get plugins configured for this hook plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) - return await self._pre_prompt_executor.execute(plugins, payload, global_context, pre_prompt_fetch, pre_prompt_matches, local_contexts) + + # Execute plugins + result = await self._pre_prompt_executor.execute(plugins, payload, global_context, pre_prompt_fetch, pre_prompt_matches, local_contexts) + + # Store contexts for potential reuse + if result[1]: + self._context_store[global_context.request_id] = (result[1], time.time()) + + return result async def prompt_post_fetch( self, payload: PromptPosthookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None ) -> tuple[PromptPosthookResult, PluginContextTable | None]: - """Plugin hook run after a prompt is rendered. + """Execute post-fetch hooks after a prompt is rendered. Args: - payload: The prompt payload to be analyzed. - global_context: contextual information for all plugins. - local_contexts: context local to a single plugin. + payload: The prompt result payload containing rendered messages. + global_context: Shared context for all plugins with request metadata. + local_contexts: Optional contexts from pre-fetch hook execution. Returns: - The result of the plugin's analysis, including whether the prompt can proceed. + A tuple containing: + - PromptPosthookResult with processing status and modified result + - PluginContextTable with final contexts + + Raises: + PayloadSizeError: If payload exceeds size limits. + + Examples: + >>> # Continuing from prompt_pre_fetch example + >>> from mcpgateway.models import PromptResult, Message, TextContent, Role + >>> from mcpgateway.plugins.framework.plugin_types import PromptPosthookPayload, GlobalContext + >>> + >>> # Create a proper Message with TextContent + >>> message = Message( + ... role=Role.USER, + ... content=TextContent(type="text", text="Hello") + ... ) + >>> prompt_result = PromptResult(messages=[message]) + >>> + >>> post_payload = PromptPosthookPayload( + ... name="greeting", + ... result=prompt_result + ... ) + >>> + >>> manager = PluginManager("plugins/config.yaml") + >>> context = GlobalContext(request_id="req-123") + >>> + >>> # In async context: + >>> # result, _ = await manager.prompt_post_fetch( + >>> # post_payload, + >>> # context, + >>> # contexts # From pre_fetch + >>> # ) + >>> # if result.modified_payload: + >>> # # Use modified result + >>> # final_result = result.modified_payload.result """ + # Get plugins configured for this hook plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_POST_FETCH) - return await self._post_prompt_executor.execute(plugins, payload, global_context, post_prompt_fetch, post_prompt_matches, local_contexts) + + # Execute plugins + result = await self._post_prompt_executor.execute(plugins, payload, global_context, post_prompt_fetch, post_prompt_matches, local_contexts) + + # Clean up stored context after post-fetch + if global_context.request_id in self._context_store: + del self._context_store[global_context.request_id] + + return result diff --git a/mcpgateway/validators.py b/mcpgateway/validators.py index a242a54b6..c7be0987a 100644 --- a/mcpgateway/validators.py +++ b/mcpgateway/validators.py @@ -52,9 +52,7 @@ class SecurityValidator: """Configurable validation with MCP-compliant limits""" # Configurable patterns (from settings) - DANGEROUS_HTML_PATTERN = ( - settings.validation_dangerous_html_pattern - ) # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' + DANGEROUS_HTML_PATTERN = settings.validation_dangerous_html_pattern # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' DANGEROUS_JS_PATTERN = settings.validation_dangerous_js_pattern # Default: javascript:|vbscript:|on\w+\s*=|data:.*script ALLOWED_URL_SCHEMES = settings.validation_allowed_url_schemes # Default: ["http://", "https://", "ws://", "wss://"] From 06ddb6d431a9f6a7be23c8cddea8aa752c1923d8 Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Wed, 6 Aug 2025 09:53:28 +0100 Subject: [PATCH 20/20] Add README and renamed plugins with filter prefix Signed-off-by: Mihai Criveti --- mcpgateway/config.py | 4 +- mcpgateway/validators.py | 4 +- plugins/config.yaml | 10 +- plugins/{filter => deny_filter}/README.md | 3 +- plugins/{filter => deny_filter}/deny.py | 0 .../plugin-manifest.yaml | 2 +- plugins/pii_filter/README.md | 1 + plugins/regex_filter/README.md | 244 ++++++++++++++++++ .../plugin-manifest.yaml | 2 +- .../{regex => regex_filter}/search_replace.py | 0 .../configs/valid_multiple_plugins.yaml | 4 +- .../valid_multiple_plugins_filter.yaml | 4 +- .../configs/valid_single_filter_plugin.yaml | 2 +- .../fixtures/configs/valid_single_plugin.yaml | 2 +- .../framework/loader/test_plugin_loader.py | 4 +- .../plugins/framework/test_manager.py | 8 +- 16 files changed, 272 insertions(+), 22 deletions(-) rename plugins/{filter => deny_filter}/README.md (97%) rename plugins/{filter => deny_filter}/deny.py (100%) rename plugins/{filter => deny_filter}/plugin-manifest.yaml (88%) create mode 100644 plugins/regex_filter/README.md rename plugins/{regex => regex_filter}/plugin-manifest.yaml (92%) rename plugins/{regex => regex_filter}/search_replace.py (100%) diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 28d4fd8df..db80a501e 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -501,7 +501,9 @@ def validate_database(self) -> None: db_dir.mkdir(parents=True) # Validation patterns for safe display (configurable) - validation_dangerous_html_pattern: str = r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" + validation_dangerous_html_pattern: str = ( + r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" + ) validation_dangerous_js_pattern: str = r"(?i)(?:^|\s|[\"'`<>=])(javascript:|vbscript:|data:\s*[^,]*[;\s]*(javascript|vbscript)|\bon[a-z]+\s*=|<\s*script\b)" diff --git a/mcpgateway/validators.py b/mcpgateway/validators.py index c7be0987a..a242a54b6 100644 --- a/mcpgateway/validators.py +++ b/mcpgateway/validators.py @@ -52,7 +52,9 @@ class SecurityValidator: """Configurable validation with MCP-compliant limits""" # Configurable patterns (from settings) - DANGEROUS_HTML_PATTERN = settings.validation_dangerous_html_pattern # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' + DANGEROUS_HTML_PATTERN = ( + settings.validation_dangerous_html_pattern + ) # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' DANGEROUS_JS_PATTERN = settings.validation_dangerous_js_pattern # Default: javascript:|vbscript:|on\w+\s*=|data:.*script ALLOWED_URL_SCHEMES = settings.validation_allowed_url_schemes # Default: ["http://", "https://", "ws://", "wss://"] diff --git a/plugins/config.yaml b/plugins/config.yaml index dd5f81391..48a83dc0b 100644 --- a/plugins/config.yaml +++ b/plugins/config.yaml @@ -5,7 +5,7 @@ plugins: - name: "PIIFilterPlugin" kind: "plugins.pii_filter.pii_filter.PIIFilterPlugin" description: "Detects and masks Personally Identifiable Information" - version: "1.0" + version: "0.1.0" author: "Mihai Criveti" hooks: ["prompt_pre_fetch", "prompt_post_fetch"] tags: ["security", "pii", "compliance", "filter", "gdpr", "hipaa"] @@ -37,9 +37,9 @@ plugins: - "555-555-5555" # Self-contained Search Replace Plugin - name: "ReplaceBadWordsPlugin" - kind: "plugins.regex.search_replace.SearchReplacePlugin" + kind: "plugins.regex_filter.search_replace.SearchReplacePlugin" description: "A plugin for finding and replacing words." - version: "0.1" + version: "0.1.0" author: "MCP Context Forge Team" hooks: ["prompt_pre_fetch", "prompt_post_fetch"] tags: ["plugin", "transformer", "regex", "search-and-replace", "pre-post"] @@ -57,9 +57,9 @@ plugins: - search: crud replace: yikes - name: "DenyListPlugin" - kind: "plugins.filter.deny.DenyListPlugin" + kind: "plugins.deny_filter.deny.DenyListPlugin" description: "A plugin that implements a deny list filter." - version: "0.1" + version: "0.1.0" author: "MCP Context Forge Team" hooks: ["prompt_pre_fetch"] tags: ["plugin", "filter", "denylist", "pre-post"] diff --git a/plugins/filter/README.md b/plugins/deny_filter/README.md similarity index 97% rename from plugins/filter/README.md rename to plugins/deny_filter/README.md index 9181f6044..b4778879b 100644 --- a/plugins/filter/README.md +++ b/plugins/deny_filter/README.md @@ -1,6 +1,7 @@ # Denylist Filter Plugin for MCP Gateway > Author: Fred Araujo +> Version: 0.1.0 A plugin for detecting deny words in MCP Gateway prompts. @@ -17,7 +18,7 @@ Detects any deny word in the prompt. If a match is found, rejects the prompt req ```yaml plugins: - name: "DenyListPlugin" - kind: "plugins.filter.deny.DenyListPlugin" + kind: "plugins.deny_filter.deny.DenyListPlugin" description: "A plugin that implements a deny list filter." version: "0.1" author: "MCP Context Forge Team" diff --git a/plugins/filter/deny.py b/plugins/deny_filter/deny.py similarity index 100% rename from plugins/filter/deny.py rename to plugins/deny_filter/deny.py diff --git a/plugins/filter/plugin-manifest.yaml b/plugins/deny_filter/plugin-manifest.yaml similarity index 88% rename from plugins/filter/plugin-manifest.yaml rename to plugins/deny_filter/plugin-manifest.yaml index d6c8ae801..a8de00b87 100644 --- a/plugins/filter/plugin-manifest.yaml +++ b/plugins/deny_filter/plugin-manifest.yaml @@ -1,6 +1,6 @@ description: "Deny list plugin manifest." author: "MCP Context Forge Team" -version: "0.1" +version: "0.1.0" available_hooks: - "prompt_pre_hook" default_configs: diff --git a/plugins/pii_filter/README.md b/plugins/pii_filter/README.md index c9c80119a..79ace8a6c 100644 --- a/plugins/pii_filter/README.md +++ b/plugins/pii_filter/README.md @@ -1,6 +1,7 @@ # PII Filter Plugin for MCP Gateway > Author: Mihai Criveti +> Version: 0.1.0 A plugin for detecting and masking Personally Identifiable Information (PII) in MCP Gateway prompts and responses. diff --git a/plugins/regex_filter/README.md b/plugins/regex_filter/README.md new file mode 100644 index 000000000..b7b54c768 --- /dev/null +++ b/plugins/regex_filter/README.md @@ -0,0 +1,244 @@ +# Search Replace Plugin for MCP Gateway + +> Author: Teryl Taylor +> Version: 0.1.0 + +A native plugin for MCP Gateway that performs regex-based search and replace operations on prompt arguments and responses. + +## Features + +- **Pre-fetch Hook**: Modifies prompt arguments before prompt retrieval +- **Post-fetch Hook**: Modifies rendered prompt messages after processing +- **Regex Support**: Full regex pattern matching and replacement +- **Multiple Patterns**: Configure multiple search/replace pairs +- **Chain Transformations**: Apply replacements in sequence + +## Installation + +The plugin is included with MCP Gateway and requires no additional installation. Simply enable it in your configuration. + +## Configuration + +Add the plugin to your `plugins/config.yaml`: + +```yaml +plugins: + - name: "SearchReplacePlugin" + kind: "plugins.regex_filter.search_replace.SearchReplacePlugin" + description: "Performs text transformations using regex patterns" + version: "0.1" + author: "MCP Context Forge Team" + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] + tags: ["transformer", "regex", "text-processing"] + mode: "enforce" # enforce | permissive | disabled + priority: 150 # Lower = higher priority + conditions: + - prompts: ["test_prompt", "chat_prompt"] # Apply to specific prompts + server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + words: + - search: "crap" + replace: "crud" + - search: "damn" + replace: "darn" + - search: "\\bAI\\b" # Word boundary regex + replace: "artificial intelligence" +``` + +## Configuration Options + +### Plugin Settings + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `name` | string | Yes | Unique plugin identifier | +| `kind` | string | Yes | Plugin class path | +| `hooks` | array | Yes | Hook points to enable | +| `mode` | string | No | Execution mode: `enforce`, `permissive`, or `disabled` | +| `priority` | integer | No | Execution order (default: 150) | +| `conditions` | array | No | Conditional execution rules | + +### Search/Replace Configuration + +| Field | Type | Description | +|-------|------|-------------| +| `words` | array | List of search/replace pairs | +| `words[].search` | string | Regex pattern to search for | +| `words[].replace` | string | Replacement text | + +## Usage Examples + +### Basic Word Replacement + +```yaml +config: + words: + - search: "hello" + replace: "greetings" + - search: "goodbye" + replace: "farewell" +``` + +### Regex Pattern Matching + +```yaml +config: + words: + # Replace email addresses with placeholder + - search: "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}" + replace: "[email]" + + # Replace phone numbers + - search: "\\b\\d{3}-\\d{3}-\\d{4}\\b" + replace: "[phone]" + + # Case-insensitive replacement + - search: "(?i)microsoft" + replace: "MS" +``` + +### Chained Transformations + +```yaml +config: + words: + # These apply in order + - search: "bad" + replace: "not good" + - search: "not good" + replace: "could be better" + # Result: "bad" → "not good" → "could be better" +``` + +## How It Works + +### Pre-fetch Hook +1. Receives prompt name and arguments +2. Applies all configured patterns to each argument value +3. Returns modified arguments for prompt rendering + +### Post-fetch Hook +1. Receives rendered prompt messages +2. Applies patterns to message content +3. Returns modified messages + +## Testing + +### Manual Testing + +1. Enable the plugin in your configuration +2. Create a test prompt: +```bash +curl -X POST http://localhost:4444/prompts \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "test_prompt", + "template": "User said: {{ message }}", + "argument_schema": { + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"] + } + }' +``` + +3. Test the replacement: +```bash +curl -X GET http://localhost:4444/prompts/test_prompt \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"message": "This is crap"}' + +# Expected: "User said: This is crud" +``` + +### Unit Testing + +```python +import pytest +from plugins.regex_filter.search_replace import SearchReplacePlugin, SearchReplaceConfig + +@pytest.mark.asyncio +async def test_search_replace(): + config = PluginConfig( + name="test", + kind="plugins.regex_filter.search_replace.SearchReplacePlugin", + version="0.1", + hooks=["prompt_pre_fetch"], + config={ + "words": [ + {"search": "foo", "replace": "bar"} + ] + } + ) + + plugin = SearchReplacePlugin(config) + payload = PromptPrehookPayload( + name="test", + args={"message": "foo is foo"} + ) + + result = await plugin.prompt_pre_fetch(payload, context) + assert result.modified_payload.args["message"] == "bar is bar" +``` + +## Performance Considerations + +- Patterns are compiled once during initialization +- Regex complexity affects performance +- Consider priority when chaining with other plugins +- Use specific prompt conditions to limit scope + +## Common Use Cases + +1. **Profanity Filter**: Replace inappropriate language +2. **Terminology Standardization**: Ensure consistent terms +3. **PII Redaction**: Simple pattern-based PII removal +4. **Format Normalization**: Standardize date/time formats +5. **Abbreviation Expansion**: Expand common abbreviations + +## Troubleshooting + +### Patterns Not Matching +- Check regex syntax and escaping +- Test patterns with online regex tools +- Enable debug logging to see transformations + +### Performance Issues +- Simplify complex regex patterns +- Reduce number of patterns +- Use prompt conditions to limit scope + +### Unexpected Results +- Remember patterns apply in order +- Check for overlapping patterns +- Test with simple inputs first + +## Available Hooks + +The plugin manifest declares support for: +- `prompt_pre_hook` - Before prompt retrieval +- `prompt_post_hook` - After prompt rendering +- `tool_pre_hook` - Before tool execution (not implemented) +- `tool_post_hook` - After tool execution (not implemented) + +Currently only prompt hooks are implemented. + +## Contributing + +To extend this plugin: + +1. Add new transformation strategies +2. Implement tool hooks +3. Add pattern validation +4. Create preset pattern libraries + +## License + +Apache-2.0 + +## Support + +For issues or questions, please open an issue in the MCP Gateway repository. \ No newline at end of file diff --git a/plugins/regex/plugin-manifest.yaml b/plugins/regex_filter/plugin-manifest.yaml similarity index 92% rename from plugins/regex/plugin-manifest.yaml rename to plugins/regex_filter/plugin-manifest.yaml index 8fc8f1505..78870aaf9 100644 --- a/plugins/regex/plugin-manifest.yaml +++ b/plugins/regex_filter/plugin-manifest.yaml @@ -1,6 +1,6 @@ description: "Search replace plugin manifest." author: "MCP Context Forge Team" -version: "0.1" +version: "0.1.0" available_hooks: - "prompt_pre_hook" - "prompt_post_hook" diff --git a/plugins/regex/search_replace.py b/plugins/regex_filter/search_replace.py similarity index 100% rename from plugins/regex/search_replace.py rename to plugins/regex_filter/search_replace.py diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml index 6a88124c5..53fad8d72 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins.yaml @@ -1,6 +1,6 @@ plugins: - name: "SynonymsPlugin" - kind: "plugins.regex.search_replace.SearchReplacePlugin" + kind: "plugins.regex_filter.search_replace.SearchReplacePlugin" description: "A plugin for finding and replacing synonyms." version: "0.1" author: "MCP Context Forge Team" @@ -21,7 +21,7 @@ plugins: replace: sullen # Self-contained Search Replace Plugin - name: "ReplaceBadWordsPlugin" - kind: "plugins.regex.search_replace.SearchReplacePlugin" + kind: "plugins.regex_filter.search_replace.SearchReplacePlugin" description: "A plugin for finding and replacing words." version: "0.1" author: "MCP Context Forge Team" diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml index bbc0fc6ad..7e6258e3a 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml @@ -1,7 +1,7 @@ plugins: # Self-contained Deny List Plugin - name: "DenyListPlugin" - kind: "plugins.filter.deny.DenyListPlugin" + kind: "plugins.deny_filter.deny.DenyListPlugin" description: "A plugin that implements a deny list filter." version: "0.1" author: "MCP Context Forge Team" @@ -21,7 +21,7 @@ plugins: - revolutionary # Self-contained Search Replace Plugin - name: "ReplaceBadWordsPlugin" - kind: "plugins.regex.search_replace.SearchReplacePlugin" + kind: "plugins.regex_filter.search_replace.SearchReplacePlugin" description: "A plugin for finding and replacing words." version: "0.1" author: "MCP Context Forge Team" diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_filter_plugin.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_filter_plugin.yaml index ba63e818b..f3e2e4fb7 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_filter_plugin.yaml +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_filter_plugin.yaml @@ -1,7 +1,7 @@ plugins: # Self-contained Deny List Plugin - name: "DenyListPlugin" - kind: "plugins.filter.deny.DenyListPlugin" + kind: "plugins.deny_filter.deny.DenyListPlugin" description: "A plugin that implements a deny list filter." version: "0.1" author: "MCP Context Forge Team" diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml index 5646b7ac5..2b7db1b90 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml @@ -1,7 +1,7 @@ plugins: # Self-contained Search Replace Plugin - name: "ReplaceBadWordsPlugin" - kind: "plugins.regex.search_replace.SearchReplacePlugin" + kind: "plugins.regex_filter.search_replace.SearchReplacePlugin" description: "A plugin for finding and replacing words." version: "0.1" author: "MCP Context Forge Team" diff --git a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py index 475c41d2f..f5d52fb82 100644 --- a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py +++ b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py @@ -15,7 +15,7 @@ from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework.models import PluginMode from mcpgateway.plugins.framework.plugin_types import GlobalContext, PluginContext, PromptPosthookPayload, PromptPrehookPayload -from plugins.regex.search_replace import SearchReplaceConfig, SearchReplacePlugin +from plugins.regex_filter.search_replace import SearchReplaceConfig, SearchReplacePlugin def test_config_loader_load(): @@ -24,7 +24,7 @@ def test_config_loader_load(): assert config assert len(config.plugins) == 1 assert config.plugins[0].name == "ReplaceBadWordsPlugin" - assert config.plugins[0].kind == "plugins.regex.search_replace.SearchReplacePlugin" + assert config.plugins[0].kind == "plugins.regex_filter.search_replace.SearchReplacePlugin" assert config.plugins[0].description == "A plugin for finding and replacing words." assert config.plugins[0].version == "0.1" assert config.plugins[0].author == "MCP Context Forge Team" diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager.py b/tests/unit/mcpgateway/plugins/framework/test_manager.py index f10f3841c..0b43119f0 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager.py @@ -12,7 +12,7 @@ from mcpgateway.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework.manager import PluginManager from mcpgateway.plugins.framework.plugin_types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload -from plugins.regex.search_replace import SearchReplaceConfig +from plugins.regex_filter.search_replace import SearchReplaceConfig @pytest.mark.asyncio @@ -20,7 +20,7 @@ async def test_manager_single_transformer_prompt_plugin(): manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") await manager.initialize() assert manager.config.plugins[0].name == "ReplaceBadWordsPlugin" - assert manager.config.plugins[0].kind == "plugins.regex.search_replace.SearchReplacePlugin" + assert manager.config.plugins[0].kind == "plugins.regex_filter.search_replace.SearchReplacePlugin" assert manager.config.plugins[0].description == "A plugin for finding and replacing words." assert manager.config.plugins[0].version == "0.1" assert manager.config.plugins[0].author == "MCP Context Forge Team" @@ -54,7 +54,7 @@ async def test_manager_multiple_transformer_preprompt_plugin(): await manager.initialize() assert manager.initialized assert manager.config.plugins[0].name == "SynonymsPlugin" - assert manager.config.plugins[0].kind == "plugins.regex.search_replace.SearchReplacePlugin" + assert manager.config.plugins[0].kind == "plugins.regex_filter.search_replace.SearchReplacePlugin" assert manager.config.plugins[0].description == "A plugin for finding and replacing synonyms." assert manager.config.plugins[0].version == "0.1" assert manager.config.plugins[0].author == "MCP Context Forge Team" @@ -66,7 +66,7 @@ async def test_manager_multiple_transformer_preprompt_plugin(): assert srconfig.words[0].search == "happy" assert srconfig.words[0].replace == "gleeful" assert manager.config.plugins[1].name == "ReplaceBadWordsPlugin" - assert manager.config.plugins[1].kind == "plugins.regex.search_replace.SearchReplacePlugin" + assert manager.config.plugins[1].kind == "plugins.regex_filter.search_replace.SearchReplacePlugin" assert manager.config.plugins[1].description == "A plugin for finding and replacing words." assert manager.config.plugins[1].version == "0.1" assert manager.config.plugins[1].author == "MCP Context Forge Team"