Skip to content

Commit a91eee0

Browse files
author
Teryl Taylor
committed
feat(plugins): added prompt posthook functionality with executor, fixed some linting issues, updated example plugin with posthook.
Signed-off-by: Teryl Taylor <[email protected]>
1 parent 6f84816 commit a91eee0

File tree

8 files changed

+212
-51
lines changed

8 files changed

+212
-51
lines changed

mcpgateway/plugins/framework/loader/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def load_config(config: str, use_jinja: bool = True) -> Config:
3636
with open(os.path.normpath(config), "r", encoding="utf-8") as file:
3737
template = file.read()
3838
if use_jinja:
39-
jinja_env = jinja2.Environment(loader=jinja2.BaseLoader())
39+
jinja_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True)
4040
rendered_template = jinja_env.from_string(template).render(env=os.environ)
4141
else:
4242
rendered_template = template

mcpgateway/plugins/framework/manager.py

Lines changed: 139 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,50 +10,164 @@
1010

1111
# Standard
1212
import logging
13-
from typing import Optional
13+
from typing import Any, Callable, Coroutine, Generic, Optional, TypeVar
1414

1515
# First-Party
16+
from mcpgateway.plugins.framework.base import PluginRef
1617
from mcpgateway.plugins.framework.loader.config import ConfigLoader
1718
from mcpgateway.plugins.framework.loader.plugin import PluginLoader
18-
from mcpgateway.plugins.framework.models import Config, HookType, PluginMode
19+
from mcpgateway.plugins.framework.models import Config, HookType, PluginCondition, PluginMode
1920
from mcpgateway.plugins.framework.registry import PluginInstanceRegistry
2021
from mcpgateway.plugins.framework.types import (
2122
GlobalContext,
2223
PluginContext,
2324
PluginContextTable,
25+
PluginResult,
2426
PromptPosthookPayload,
2527
PromptPosthookResult,
2628
PromptPrehookPayload,
2729
PromptPrehookResult,
2830
)
29-
from mcpgateway.plugins.framework.utils import pre_prompt_matches
31+
from mcpgateway.plugins.framework.utils import post_prompt_matches, pre_prompt_matches
3032

3133
logger = logging.getLogger(__name__)
3234

35+
T = TypeVar('T')
36+
37+
38+
class PluginExecutor(Generic[T]):
39+
"""Executes a list of plugins."""
40+
async def execute(
41+
self,
42+
plugins: list[PluginRef],
43+
payload: T,
44+
global_context: GlobalContext,
45+
plugin_run: Callable[[PluginRef, T, PluginContext], Coroutine[Any, Any, PluginResult[T]]],
46+
compare: Callable[[T, list[PluginCondition], GlobalContext], bool],
47+
local_contexts: Optional[PluginContextTable] = None,
48+
) -> tuple[PluginResult[T] | None, PluginContextTable | None]:
49+
"""Execute a plugins hook run before a prompt is retrieved and rendered.
50+
51+
Args:
52+
plugins: the list of plugins to execute.
53+
payload: the payload to be analyzed.
54+
global_context: contextual information for all plugins.
55+
plugin_run: async function for executing plugin hook.
56+
compare: function for comparing conditional information with context and payload
57+
local_contexts: context local to a single plugin.
58+
59+
Returns:
60+
The result of the plugin's analysis, including whether the prompt can proceed.
61+
"""
62+
if not plugins:
63+
return (PluginResult[T](modified_payload=None), None)
64+
65+
res_local_contexts = {}
66+
combined_metadata = {}
67+
current_payload: T | None = None
68+
for pluginref in plugins:
69+
if not pluginref.conditions or not compare(payload, pluginref.conditions, global_context):
70+
continue
71+
local_context_key = global_context.request_id + pluginref.uuid
72+
if local_contexts and local_context_key in local_contexts:
73+
local_context = local_contexts[local_context_key]
74+
else:
75+
local_context = PluginContext(global_context)
76+
res_local_contexts[local_context_key] = local_context
77+
result = await plugin_run(pluginref, payload, local_context)
78+
79+
if result.metadata:
80+
combined_metadata.update(result.metadata)
81+
82+
if result.modified_payload is not None:
83+
current_payload = result.modified_payload
84+
85+
if not result.continue_processing:
86+
# Check execution mode
87+
if pluginref.plugin.mode == PluginMode.ENFORCE:
88+
return (PluginResult[T](continue_processing=False, modified_payload=current_payload, error=result.error, metadata=combined_metadata), None)
89+
elif pluginref.plugin.mode == PluginMode.PERMISSIVE:
90+
logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.error}")
91+
92+
return (PluginResult[T](continue_processing=True, modified_payload=current_payload, error=None, metadata=combined_metadata), res_local_contexts)
93+
94+
95+
async def pre_prompt_fetch(plugin: PluginRef, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult:
96+
"""Call plugin's prompt pre-fetch hook.
97+
98+
Args:
99+
plugin: the plugin to execute.
100+
payload: the prompt payload to be analyzed.
101+
context: contextual information about the hook call. Including why it was called.
102+
103+
Returns:
104+
The result of the plugin execution.
105+
"""
106+
return await plugin.plugin.prompt_pre_fetch(payload, context)
107+
108+
109+
async def post_prompt_fetch(plugin: PluginRef, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult:
110+
"""Call plugin's prompt post-fetch hook.
111+
112+
Args:
113+
plugin: the plugin to execute.
114+
payload: the prompt payload to be analyzed.
115+
context: contextual information about the hook call. Including why it was called.
116+
117+
Returns:
118+
The result of the plugin execution.
119+
"""
120+
return await plugin.plugin.prompt_post_fetch(payload, context)
121+
33122

34123
class PluginManager:
35124
"""Plugin manager for managing the plugin lifecycle."""
36125

37-
def __init__(self, config: str):
126+
__shared_state: dict[Any, Any] = {}
127+
_loader: PluginLoader = PluginLoader()
128+
_initialized: bool = False
129+
_registry: PluginInstanceRegistry = PluginInstanceRegistry()
130+
_config: Config | None = None
131+
_pre_prompt_executor: PluginExecutor[PromptPrehookPayload] = PluginExecutor[PromptPrehookPayload]()
132+
_post_prompt_executor: PluginExecutor[PromptPosthookPayload] = PluginExecutor[PromptPosthookPayload]()
133+
134+
def __init__(self, config: str = ""):
38135
"""Initialize plugin manager.
39136
40137
Args:
41138
config: plugin configuration path.
42139
"""
43-
self._config: Config = ConfigLoader.load_config(config)
44-
self._initialized: bool = False
45-
self._loader: PluginLoader = PluginLoader()
46-
self._registry: PluginInstanceRegistry = PluginInstanceRegistry()
140+
self.__dict__ = self.__shared_state
141+
if config:
142+
self._config = ConfigLoader.load_config(config)
47143

48144
@property
49-
def config(self) -> Config:
145+
def config(self) -> Config | None:
50146
"""Plugin manager configuration.
51147
52148
Returns:
53149
The plugin configuration.
54150
"""
55151
return self._config
56152

153+
@property
154+
def plugin_count(self) -> int:
155+
"""Number of plugins loaded.
156+
157+
Returns:
158+
The number of plugins loaded.
159+
"""
160+
return self._registry.plugin_count
161+
162+
@property
163+
def initialized(self) -> bool:
164+
"""Plugin manager initialized.
165+
166+
Returns:
167+
True if the plugin manager is initialized.
168+
"""
169+
return self._initialized
170+
57171
async def initialize(self) -> None:
58172
"""Initialize the plugin manager.
59173
@@ -62,8 +176,10 @@ async def initialize(self) -> None:
62176
"""
63177
if self._initialized:
64178
return
179+
180+
plugins = self._config.plugins if self._config else []
65181

66-
for plugin_config in self._config.plugins:
182+
for plugin_config in plugins:
67183
if plugin_config.mode != PluginMode.DISABLED:
68184
plugin = await self._loader.load_and_instantiate_plugin(plugin_config)
69185
if plugin:
@@ -73,6 +189,16 @@ async def initialize(self) -> None:
73189
self._initialized = True
74190
logger.info(f"Plugin manager initialized with {len(self._registry.get_all_plugins())} plugins")
75191

192+
async def shutdown(self) -> None:
193+
"""Shutdown all plugins."""
194+
for plugin_ref in self._registry.get_all_plugins():
195+
try:
196+
await plugin_ref.plugin.shutdown()
197+
except Exception as e:
198+
logger.error(f"Error shutting down plugin {plugin_ref.plugin.name}: {e}")
199+
200+
self._initialized = False
201+
76202
async def prompt_pre_fetch(
77203
self,
78204
payload: PromptPrehookPayload,
@@ -90,38 +216,9 @@ async def prompt_pre_fetch(
90216
The result of the plugin's analysis, including whether the prompt can proceed.
91217
"""
92218
plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH)
219+
return await self._pre_prompt_executor.execute(plugins, payload, global_context, pre_prompt_fetch, pre_prompt_matches, local_contexts)
93220

94-
if not plugins:
95-
return (PromptPrehookResult(modified_payload=payload), None)
96-
97-
res_local_contexts = {}
98-
combined_metadata = {}
99-
current_payload: PromptPrehookPayload | None = None
100-
for pluginref in plugins:
101-
if not pluginref.conditions or not pre_prompt_matches(payload, pluginref.conditions, global_context):
102-
continue
103-
local_context_key = global_context.request_id + pluginref.uuid
104-
if local_contexts and local_context_key in local_contexts:
105-
local_context = local_contexts[local_context_key]
106-
else:
107-
local_context = PluginContext(global_context)
108-
res_local_contexts[local_context_key] = local_context
109-
result = await pluginref.plugin.prompt_pre_fetch(payload, local_context)
110-
111-
if result.metadata:
112-
combined_metadata.update(result.metadata)
113-
114-
if result.modified_payload is not None:
115-
current_payload = result.modified_payload
116-
117-
if not result.continue_processing:
118-
# Check execution mode
119-
if pluginref.plugin.mode == PluginMode.ENFORCE:
120-
return (PromptPrehookResult(continue_processing=False, modified_payload=current_payload, error=result.error, metadata=combined_metadata), None)
121-
elif pluginref.plugin.mode == PluginMode.PERMISSIVE:
122-
logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.error}")
123221

124-
return (PromptPrehookResult(continue_processing=True, modified_payload=current_payload, error=None, metadata=combined_metadata), res_local_contexts)
125222

126223
async def prompt_post_fetch(
127224
self, payload: PromptPosthookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None
@@ -136,4 +233,5 @@ async def prompt_post_fetch(
136233
Returns:
137234
The result of the plugin's analysis, including whether the prompt can proceed.
138235
"""
139-
return (None, None)
236+
plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_POST_FETCH)
237+
return await self._post_prompt_executor.execute(plugins, payload, global_context, post_prompt_fetch, post_prompt_matches, local_contexts)

mcpgateway/plugins/framework/registry.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,11 @@ def get_all_plugins(self) -> list[PluginRef]:
105105
A list of registered plugin instances.
106106
"""
107107
return list(self._plugins.values())
108+
109+
def plugin_count(self) -> int:
110+
"""Return the number of plugins registered.
111+
112+
Returns:
113+
The number of plugins registered.
114+
"""
115+
return len(self._plugins)

mcpgateway/plugins/framework/types.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,21 @@ def __init__(self, name: str, args: Optional[dict[str, str]]):
3030
args: The prompt arguments for rendering.
3131
"""
3232
self.name = name
33-
self.args = args
33+
self.args = args or {}
3434

3535

36-
PromptPosthookPayload = PromptResult
36+
class PromptPosthookPayload:
37+
"""A prompt payload for a prompt posthook."""
38+
39+
def __init__(self, name: str, result: PromptResult):
40+
"""Initialize a prompt posthook payload.
41+
42+
Args:
43+
name: The prompt name.
44+
result: The prompt Prompt Result.
45+
"""
46+
self.name = name
47+
self.result = result
3748

3849

3950
class PluginResult(Generic[T]):
@@ -66,7 +77,7 @@ def __init__(
6677
request_id: str,
6778
user: Optional[str] = None,
6879
tenant_id: Optional[str] = None,
69-
server_id: Optional[str] = None,
80+
server_id: Optional[str] = None
7081
) -> None:
7182
"""Initialize a global context.
7283
@@ -81,7 +92,6 @@ def __init__(
8192
self.tenant_id = tenant_id
8293
self.server_id = server_id
8394

84-
8595
class PluginContext(GlobalContext):
8696
"""The plugin's context, which lasts a request lifecycle.
8797

mcpgateway/plugins/framework/utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# First-Party
1818
from mcpgateway.plugins.framework.models import PluginCondition
19-
from mcpgateway.plugins.framework.types import GlobalContext, PromptPrehookPayload
19+
from mcpgateway.plugins.framework.types import GlobalContext, PromptPosthookPayload, PromptPrehookPayload
2020

2121

2222
@cache # noqa
@@ -95,3 +95,27 @@ def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCon
9595
elif index < len(conditions) - 1:
9696
current_result = True
9797
return current_result
98+
99+
def post_prompt_matches(payload: PromptPosthookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool:
100+
"""Check for a match on pre-prompt hooks.
101+
102+
Args:
103+
payload: the prompt posthook payload.
104+
conditions: the conditions on the plugin that are required for execution.
105+
context: the global context.
106+
107+
Returns:
108+
True if the plugin matches criteria.
109+
"""
110+
current_result = True
111+
for index, condition in enumerate(conditions):
112+
if not matches(condition, context):
113+
current_result = False
114+
115+
if condition.prompts and payload.name not in condition.prompts:
116+
current_result = False
117+
if current_result:
118+
return True
119+
elif index < len(conditions) - 1:
120+
current_result = True
121+
return current_result

plugins/config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ plugins:
2020
words:
2121
- search: crap
2222
replace: crud
23+
- search: crud
24+
replace: yikes
2325

2426

2527
# Plugin directories to scan

plugins/regex/plugin-manifest.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ available_hooks:
66
- "prompt_post_hook"
77
- "tool_pre_hook"
88
- "tool_post_hook"
9-
default_configs:
9+
default_configs:

0 commit comments

Comments
 (0)