Skip to content

Commit 12152eb

Browse files
authored
Merge pull request #21 from eccenca/feature/pluginActions-CMEM-5576
Add plugin actions.
2 parents 265fb40 + 4d899ca commit 12152eb

File tree

3 files changed

+152
-7
lines changed

3 files changed

+152
-7
lines changed

CHANGELOG.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/) and this p
88

99
### Added
1010

11-
- Added `explicit_schema` parameter to `FlexibleOutputSchema` (CMEM-6444).
12-
11+
- Custom actions for Workflow plugins (CMEM-5576).
12+
- Added explicit_schema parameter to FlexibleOutputSchema (CMEM-6444).
1313

1414
## [4.9.0] 2025-02-20
1515

cmem_plugin_base/dataintegration/description.py

+92-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from pkgutil import get_data
1010
from typing import Any, ClassVar
1111

12-
from cmem_plugin_base.dataintegration.plugins import TransformPlugin, WorkflowPlugin
12+
from cmem_plugin_base.dataintegration.context import PluginContext
13+
from cmem_plugin_base.dataintegration.plugins import PluginBase, TransformPlugin, WorkflowPlugin
1314
from cmem_plugin_base.dataintegration.types import (
1415
ParameterType,
1516
ParameterTypes,
@@ -96,6 +97,84 @@ def __init__( # noqa: PLR0913
9697
self.visible = visible
9798

9899

100+
class PluginAction:
101+
"""Custom plugin action.
102+
103+
Plugin actions provide additional functionality besides the default execution.
104+
They can be triggered from the plugin UI.
105+
Each action is based on a method on the plugin class. Besides the self parameter,
106+
the method can have one additional parameter of type PluginContext.
107+
The return value of the method will be converted to a string and displayed in the UI.
108+
The string may use Markdown formatting.
109+
The method may return None, in which case no output will be displayed.
110+
It may raise an exception to signal an error to the user.
111+
112+
:param name: The name of the method.
113+
:param label: A human-readable label of the action
114+
:param description: A human-readable description of the action
115+
:param icon: An optional custom icon.
116+
"""
117+
118+
def __init__(self, name: str, label: str, description: str, icon: Icon | None = None):
119+
self.name = name
120+
self.label = label
121+
self.description = description
122+
self.icon = icon
123+
self.validated = False
124+
self.provide_plugin_context = False # Will be set by validate()
125+
126+
def validate(self, plugin_class: type) -> None:
127+
"""Validate the action and set the `provide_plugin_context` boolean.
128+
129+
:param plugin_class: The plugin class
130+
"""
131+
# Get the method from the class.
132+
try:
133+
method = getattr(plugin_class, self.name)
134+
except AttributeError:
135+
raise TypeError(
136+
f"Plugin class '{plugin_class.__name__}' does not have a method named '{self.name}'"
137+
) from None
138+
if not callable(method):
139+
raise TypeError(f"'{self.name}' in class '{plugin_class.__name__}' is not a function.")
140+
141+
# Check parameters
142+
parameters = list(inspect.signature(method).parameters.values())
143+
if len(parameters) == 1:
144+
self.provide_plugin_context = False
145+
elif len(parameters) - 1 == 1:
146+
if parameters[1].annotation is PluginContext:
147+
self.provide_plugin_context = True
148+
else:
149+
raise TypeError(
150+
f"Argument of method '{self.name}' in {plugin_class.__name__} must "
151+
f"be typed PluginContext (it's {parameters[1].annotation})."
152+
)
153+
else:
154+
raise TypeError(
155+
f"Method '{self.name}' in {plugin_class.__name__} has more than one"
156+
f" argument (besides 'self')."
157+
)
158+
self.validated = True
159+
160+
def execute(self, plugin: PluginBase, context: PluginContext) -> str | None:
161+
"""Call the action.
162+
163+
:param plugin: The plugin instance on which the action is called.
164+
:param context: The plugin context
165+
:return: The result of the action as string
166+
"""
167+
if not self.validated:
168+
raise ValueError("Action must be validated before it can be executed.")
169+
if self.provide_plugin_context:
170+
result = getattr(plugin, self.name)(context)
171+
else:
172+
result = getattr(plugin, self.name)()
173+
if result is None:
174+
return None
175+
return str(result)
176+
177+
99178
class PluginDescription:
100179
"""A plugin description.
101180
@@ -106,6 +185,7 @@ class PluginDescription:
106185
:param categories: The categories to which this plugin belongs to.
107186
:param parameters: Available plugin parameters
108187
:param icon: An optional custom plugin icon.
188+
:param actions: Custom plugin actions.
109189
"""
110190

111191
def __init__( # noqa: PLR0913
@@ -118,6 +198,7 @@ def __init__( # noqa: PLR0913
118198
categories: list[str] | None = None,
119199
parameters: list[PluginParameter] | None = None,
120200
icon: Icon | None = None,
201+
actions: list[PluginAction] | None = None,
121202
) -> None:
122203
# Set the type of the plugin. Same as the class name of the plugin
123204
# base class, e.g., 'WorkflowPlugin'.
@@ -153,6 +234,12 @@ def __init__( # noqa: PLR0913
153234
else:
154235
self.parameters = parameters
155236
self.icon = icon
237+
if actions is None:
238+
self.actions = []
239+
else:
240+
self.actions = actions
241+
for action in self.actions:
242+
action.validate(plugin_class)
156243

157244

158245
@dataclass
@@ -233,6 +320,7 @@ class Plugin:
233320
:param categories: The categories to which this plugin belongs to.
234321
:param parameters: Available plugin parameters.
235322
:param icon: Optional custom plugin icon.
323+
:param actions: Custom plugin actions
236324
"""
237325

238326
plugins: ClassVar[list[PluginDescription]] = []
@@ -246,12 +334,14 @@ def __init__( # noqa: PLR0913
246334
categories: list[str] | None = None,
247335
parameters: list[PluginParameter] | None = None,
248336
icon: Icon | None = None,
337+
actions: list[PluginAction] | None = None,
249338
):
250339
self.label = label
251340
self.description = description
252341
self.documentation = documentation
253342
self.plugin_id = plugin_id
254343
self.icon = icon
344+
self.actions = actions
255345
if categories is None:
256346
self.categories = []
257347
else:
@@ -272,6 +362,7 @@ def __call__(self, func: type):
272362
categories=self.categories,
273363
parameters=self.retrieve_parameters(func),
274364
icon=self.icon,
365+
actions=self.actions,
275366
)
276367
Plugin.plugins.append(plugin_desc)
277368
return func

tests/test_description.py

+58-4
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,24 @@
33
import unittest
44
from collections.abc import Sequence
55

6-
from cmem_plugin_base.dataintegration.description import Plugin
7-
from cmem_plugin_base.dataintegration.plugins import TransformPlugin
6+
from cmem_plugin_base.dataintegration.context import ExecutionContext, PluginContext
7+
from cmem_plugin_base.dataintegration.description import Plugin, PluginAction
8+
from cmem_plugin_base.dataintegration.entity import Entities
9+
from cmem_plugin_base.dataintegration.plugins import TransformPlugin, WorkflowPlugin
810
from cmem_plugin_base.dataintegration.types import (
911
BoolParameterType,
1012
FloatParameterType,
1113
StringParameterType,
1214
)
15+
from cmem_plugin_base.testing import TestPluginContext
1316

1417

1518
class PluginTest(unittest.TestCase):
1619
"""Plugin Test Class"""
1720

1821
def test__basic_parameters(self) -> None:
1922
"""Test basic parameters"""
20-
Plugin.plugins = []
23+
Plugin.plugins = [] # Remove all previous plugins
2124

2225
@Plugin(label="My Transform Plugin")
2326
class MyTransformPlugin(TransformPlugin):
@@ -38,7 +41,6 @@ def __init__(
3841
def transform(self, inputs: Sequence[Sequence[str]]) -> Sequence[str]:
3942
return []
4043

41-
_ = MyTransformPlugin
4244
plugin = Plugin.plugins[0]
4345

4446
no_default_par = plugin.parameters[0]
@@ -61,6 +63,58 @@ def transform(self, inputs: Sequence[Sequence[str]]) -> Sequence[str]:
6163
assert bool_par.param_type.name == BoolParameterType.name
6264
assert bool_par.default_value is True
6365

66+
def test__actions(self) -> None:
67+
"""Test plugin actions"""
68+
Plugin.plugins = [] # Remove all previous plugins
69+
70+
@Plugin(
71+
label="My Workflow Plugin",
72+
actions=[
73+
PluginAction(
74+
name="get_name",
75+
label="Get name",
76+
description="Returns the supplied name",
77+
),
78+
PluginAction(
79+
name="get_project",
80+
label="Get project",
81+
description="Returns the current project.",
82+
),
83+
],
84+
)
85+
class MyWorkflowPlugin(WorkflowPlugin):
86+
"""Test workflow plugin"""
87+
88+
def __init__(self, name: str) -> None:
89+
self.name = name
90+
91+
def execute(self, inputs: Sequence[Entities], context: ExecutionContext) -> Entities:
92+
return inputs[0]
93+
94+
def get_name(self) -> str:
95+
return self.name
96+
97+
def get_project(self, context: PluginContext) -> str:
98+
return context.project_id
99+
100+
# Get plugin description
101+
plugin = Plugin.plugins[0]
102+
103+
# There should be two actions
104+
assert len(plugin.actions) == 2
105+
action1 = plugin.actions[0]
106+
action2 = plugin.actions[1]
107+
108+
# Check first action
109+
assert action1.name == "get_name"
110+
assert action1.label == "Get name"
111+
assert action1.description == "Returns the supplied name"
112+
113+
# Call actions on a plugin instance
114+
plugin_instance = MyWorkflowPlugin("My Name")
115+
assert action1.execute(plugin_instance, TestPluginContext()) == "My Name"
116+
assert action2.execute(plugin_instance, TestPluginContext(project_id="movies")) == "movies"
117+
64118

65119
if __name__ == "__main__":
66120
unittest.main()

0 commit comments

Comments
 (0)