diff --git a/dev-requirements.txt b/dev-requirements.txt index 828ef8aa4..ee12592ff 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -2,7 +2,9 @@ mypy>=1.2.0 mypy-extensions>=0.4.3 mypy-protobuf>=2.9 tox>=4.3.0 +pip>=23.0.0 coverage>=5.3 +pytest wheel # used in unit test only opentelemetry-sdk diff --git a/examples/workflow-async/README.md b/examples/workflow-async/README.md new file mode 100644 index 000000000..ca0f88fed --- /dev/null +++ b/examples/workflow-async/README.md @@ -0,0 +1,86 @@ +# Dapr Workflow Async Examples (Python) + +These examples mirror `examples/workflow/` but author orchestrators with `async def` using the +async workflow APIs. Activities can be either sync or async functions. + +## Prerequisites + +- [Dapr CLI and initialized environment](https://docs.dapr.io/getting-started) +- [Install Python 3.10+](https://www.python.org/downloads/) + + +## How to run +- Install the Dapr CLI: `brew install dapr/tap/dapr-cli` or `choco install dapr-cli` +- Initialize Dapr: `dapr init` +- Install requirements: + ```bash + cd examples/workflow-async + python -m venv .venv + source .venv/bin/activate + pip install -r requirements.txt + ``` + + or, for faster installs, with `uv`: + ```bash + uv venv .venv + source .venv/bin/activate + uv pip install -r requirements.txt + ``` +- Run any example with Dapr (from this directory, with the virtual environment activated): + - `dapr run --app-id wf_async_simple -- python simple.py` + - `dapr run --app-id wf_task_chain -- python task_chaining.py` + - `dapr run --app-id wf_async_child -- python child_workflow.py` + - `dapr run --app-id wf_async_fafi -- python fan_out_fan_in.py` + - `dapr run --app-id wf_async_gather -- python fan_out_fan_in_with_gather.py` + - `dapr run --app-id wf_async_approval -- python human_approval.py` + - `dapr run --app-id wf_ctx_interceptors -- python context_interceptors_example.py` + - `dapr run --app-id wf_async_http -- python async_http_activity.py` + +## Examples + +- **simple.py**: Comprehensive example showing activities, child workflows, retry policies, and external events +- **task_chaining.py**: Sequential activity calls where each result feeds into the next +- **child_workflow.py**: Parent workflow calling a child workflow +- **fan_out_fan_in.py**: Parallel activity execution using `ctx.when_all()` +- **fan_out_fan_in_with_gather.py**: Parallel execution using `asyncio.gather()` +- **human_approval.py**: Workflow waiting for an external event to proceed +- **context_interceptors_example.py**: Context propagation using interceptors (tenant, request ID, etc.) +- **async_http_activity.py**: Async activities performing I/O-bound operations (HTTP requests with aiohttp) + +Notes: +- Orchestrators use `await ctx.call_activity(...)`, `await ctx.create_timer(...)`, `await ctx.when_all/when_any(...)`, etc. +- No event loop is started manually; the Durable Task worker drives the async orchestrators. +- You can also launch instances using `DaprWorkflowClient` as in the non-async examples.
+- The interceptors example demonstrates how to propagate context (tenant, request ID) across workflow and activity boundaries using the wrapper pattern to avoid contextvar loss. + +## Async Activities + +Activities can be either synchronous or asynchronous functions. Async activities are useful for I/O-bound operations like HTTP requests, database queries, or file operations: + +```python +import aiohttp + +from dapr.ext.workflow import WorkflowActivityContext + +# Synchronous activity +@wfr.activity +def sync_activity(ctx: WorkflowActivityContext, data: str) -> str: + return data.upper() + +# Asynchronous activity +@wfr.activity +async def async_activity(ctx: WorkflowActivityContext, data: str) -> str: + # Perform async I/O operations + async with aiohttp.ClientSession() as session: + async with session.get(f"https://api.example.com/{data}") as response: + result = await response.json() + return result +``` + +Both sync and async activities are registered the same way using the `@wfr.activity` decorator. Orchestrators call them identically regardless of whether they're sync or async; the SDK handles the execution automatically. + +**When to use async activities:** +- HTTP requests or API calls +- Database queries +- File I/O operations +- Any I/O-bound work that benefits from async/await + +See `async_http_activity.py` for a complete example. diff --git a/examples/workflow-async/async_http_activity.py b/examples/workflow-async/async_http_activity.py new file mode 100644 index 000000000..ce400a212 --- /dev/null +++ b/examples/workflow-async/async_http_activity.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import ( # noqa: E402 + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, + WorkflowStatus, +) + +"""Example demonstrating async activities with HTTP requests. + +This example shows how to use async activities to perform I/O-bound operations +like HTTP requests without blocking the worker thread pool. +""" + + +wfr = WorkflowRuntime() + + +@wfr.activity(name='fetch_url') +async def fetch_url(ctx: WorkflowActivityContext, url: str) -> dict: + """Async activity that fetches data from a URL. + + This demonstrates using aiohttp for non-blocking HTTP requests. + In production, you would handle errors, timeouts, and retries. + """ + try: + import aiohttp + except ImportError: + # Fallback if aiohttp is not installed + return { + 'url': url, + 'status': 'error', + 'message': 'aiohttp not installed. 
Install with: pip install aiohttp', + } + + try: + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as response: + status = response.status + if status == 200: + # For JSON responses + try: + data = await response.json() + return {'url': url, 'status': status, 'data': data} + except Exception: + # For text responses + text = await response.text() + return { + 'url': url, + 'status': status, + 'length': len(text), + 'preview': text[:100], + } + else: + return {'url': url, 'status': status, 'error': 'HTTP error'} + except Exception as e: + return {'url': url, 'status': 'error', 'message': str(e)} + + +@wfr.activity(name='process_data') +def process_data(ctx: WorkflowActivityContext, data: dict) -> dict: + """Sync activity that processes fetched data. + + This shows that sync and async activities can coexist in the same workflow. + """ + return { + 'processed': True, + 'url_count': len([k for k in data if k.startswith('url_')]), + 'summary': f'Processed {len(data)} items', + } + + +@wfr.async_workflow(name='fetch_multiple_urls_async') +async def fetch_multiple_urls(ctx: AsyncWorkflowContext, urls: list[str]) -> dict: + """Orchestrator that fetches multiple URLs in parallel using async activities. + + This demonstrates: + - Calling async activities from async workflows + - Fan-out/fan-in pattern with async activities + - Mixing async and sync activities + """ + # Fan-out: Schedule all URL fetches in parallel + fetch_tasks = [ctx.call_activity(fetch_url, input=url) for url in urls] + + # Fan-in: Wait for all to complete + results = await ctx.when_all(fetch_tasks) + + # Create a dictionary of results + data = {f'url_{i}': result for i, result in enumerate(results)} + + # Process the aggregated data with a sync activity + summary = await ctx.call_activity(process_data, input=data) + + return {'results': data, 'summary': summary} + + +def main(): + """Run the example workflow.""" + # Example URLs to fetch (using httpbin.org for testing) + test_urls = [ + 'https://httpbin.org/json', + 'https://httpbin.org/uuid', + 'https://httpbin.org/user-agent', + ] + + wfr.start() + client = DaprWorkflowClient() + + try: + instance_id = 'async_http_activity_example' + print(f'Starting workflow {instance_id}...') + + # Schedule the workflow + client.schedule_new_workflow( + workflow=fetch_multiple_urls, instance_id=instance_id, input=test_urls + ) + + # Wait for completion + wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + + print(f'\nWorkflow status: {wf_state.runtime_status}') + + if wf_state.runtime_status == WorkflowStatus.COMPLETED: + print(f'Workflow output: {wf_state.serialized_output}') + print('\n✓ Workflow completed successfully!') + else: + print('✗ Workflow did not complete successfully') + return 1 + + finally: + wfr.shutdown() + + return 0 + + +if __name__ == '__main__': + import sys + + sys.exit(main()) diff --git a/examples/workflow-async/child_workflow.py b/examples/workflow-async/child_workflow.py new file mode 100644 index 000000000..cb063e961 --- /dev/null +++ b/examples/workflow-async/child_workflow.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowRuntime, + WorkflowStatus, +) + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name='child_async') +async def child(ctx: AsyncWorkflowContext, n: int) -> int: + return n * 2 + + +@wfr.async_workflow(name='parent_async') +async def parent(ctx: AsyncWorkflowContext, n: int) -> int: + r = await ctx.call_child_workflow(child, input=n) + print(f'Child workflow returned {r}') + return r + 1 + + +def main(): + with wfr: + # the context manager starts the workflow runtime on __enter__ and shutdown on __exit__ + client = DaprWorkflowClient() + instance_id = 'parent_async_instance' + client.schedule_new_workflow(workflow=parent, input=5, instance_id=instance_id) + wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + + # simple test + if wf_state.runtime_status != WorkflowStatus.COMPLETED: + print('Workflow failed with status ', wf_state.runtime_status) + exit(1) + if wf_state.serialized_output != '11': + print('Workflow result is incorrect!') + exit(1) + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/context_interceptors_example.py b/examples/workflow-async/context_interceptors_example.py new file mode 100644 index 000000000..b1f9ae3d6 --- /dev/null +++ b/examples/workflow-async/context_interceptors_example.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Example: Interceptors for context propagation with async workflows using metadata envelope. + +This example demonstrates the RECOMMENDED approach for context propagation: + - Use metadata envelope for durable, transparent context propagation + - ClientInterceptor sets metadata when scheduling workflows + - RuntimeInterceptor restores context from metadata before execution + - WorkflowOutboundInterceptor propagates metadata to activities/child workflows + - Use wrapper pattern with 'yield from' to keep context alive during execution + +CRITICAL: Workflow interceptors MUST use the wrapper pattern and return the result: + def execute_workflow(self, request, nxt): + def wrapper(): + setup_context() + try: + gen = nxt(request) + result = yield from gen # Keep context alive during execution + return result # MUST return to propagate workflow output + finally: + cleanup_context() + return wrapper() + +Without 'return result', the workflow output will be lost (serialized_output will be null). 
+ +Metadata envelope approach (RECOMMENDED): +------------------------------------------ +Metadata is stored separately from user payload and transparently wrapped/unwrapped by runtime. + +Benefits: + - User code receives only the payload (never sees envelope) + - Durably persisted (survives replays, retries, continue-as-new) + - Automatic propagation across workflow → activity → child workflow boundaries + - String-only metadata enforces simple, serializable key-value structure + - Context accessible to interceptors via request.metadata + +Note: This requires a running Dapr sidecar to execute. +""" + +from __future__ import annotations + +import contextvars +from dataclasses import replace +from typing import Any, Callable + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + DaprWorkflowClient, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + ScheduleWorkflowRequest, + WorkflowActivityContext, + WorkflowRuntime, +) + +# Context variable to carry request metadata across workflow/activity execution +_request_context: contextvars.ContextVar[dict[str, str] | None] = contextvars.ContextVar( + 'request_context', default=None +) + + +def set_request_context(ctx: dict[str, str] | None) -> None: + """Set the current context (stored in contextvar).""" + _request_context.set(ctx) + + +def get_request_context() -> dict[str, str] | None: + """Get the current context from contextvar.""" + return _request_context.get() + + +class ContextClientInterceptor(BaseClientInterceptor): + """Client interceptor that sets metadata when scheduling workflows. + + The metadata is automatically wrapped in an envelope by the runtime and + propagated durably across workflow boundaries. + """ + + def schedule_new_workflow( + self, request: ScheduleWorkflowRequest, nxt: Callable[[ScheduleWorkflowRequest], Any] + ) -> Any: + # Get current context and convert to string-only metadata + ctx = get_request_context() + metadata = ctx.copy() if ctx else {} + + print('[Client] Scheduling workflow with metadata:', metadata) + + # Set metadata on the request (runtime will wrap in envelope) + return nxt(replace(request, metadata=metadata)) + + +class ContextWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + """Workflow outbound interceptor that propagates metadata to activities and child workflows. + + The metadata is automatically wrapped in an envelope by the runtime. + """ + + def call_activity( + self, request: CallActivityRequest, nxt: Callable[[CallActivityRequest], Any] + ) -> Any: + # Get current context and convert to string-only metadata + ctx = get_request_context() + metadata = ctx.copy() if ctx else {} + + return nxt(replace(request, metadata=metadata)) + + def call_child_workflow( + self, request: CallChildWorkflowRequest, nxt: Callable[[CallChildWorkflowRequest], Any] + ) -> Any: + # Get current context and convert to string-only metadata + ctx = get_request_context() + metadata = ctx.copy() if ctx else {} + + return nxt(replace(request, metadata=metadata)) + + +class ContextRuntimeInterceptor(BaseRuntimeInterceptor): + """Runtime interceptor that restores context from metadata before execution. + + The runtime automatically unwraps the envelope and provides metadata via + request.metadata. User code receives only the original payload via request.input. 
+ """ + + def execute_workflow( + self, request: ExecuteWorkflowRequest, nxt: Callable[[ExecuteWorkflowRequest], Any] + ) -> Any: + """ + IMPORTANT: Use wrapper pattern to keep context alive during generator execution. + + Calling nxt(request) returns a generator immediately; context must stay set + while that generator executes (including during activity calls and child workflows). + """ + + def wrapper(): + # Restore context from metadata (automatically unwrapped by runtime) + if request.metadata: + set_request_context(request.metadata) + + try: + gen = nxt(request) + result = yield from gen # Keep context alive while generator executes + return result # Must explicitly return the result from the inner generator + finally: + set_request_context(None) + + return wrapper() + + def execute_activity( + self, request: ExecuteActivityRequest, nxt: Callable[[ExecuteActivityRequest], Any] + ) -> Any: + """ + Restore context from metadata before activity execution. + + The runtime automatically unwraps the envelope and provides metadata via + request.metadata. User code receives only the original payload. + """ + # Restore context from metadata (automatically unwrapped by runtime) + if request.metadata: + set_request_context(request.metadata) + + try: + return nxt(request) + finally: + set_request_context(None) + + +# Create runtime with interceptors +wfr = WorkflowRuntime( + runtime_interceptors=[ContextRuntimeInterceptor()], + workflow_outbound_interceptors=[ContextWorkflowOutboundInterceptor()], +) + + +@wfr.activity(name='process_data') +def process_data(ctx: WorkflowActivityContext, data: dict) -> dict: + """ + Activity that accesses the restored context. + + The context was set in the runtime interceptor from metadata. + The activity receives only the user payload (data), not the envelope. + """ + request_ctx = get_request_context() + + if request_ctx is None: + return {'tenant': 'unknown', 'request_id': 'unknown', 'message': 'no message', 'data': data} + + return { + 'tenant': request_ctx.get('tenant', 'unknown'), + 'request_id': request_ctx.get('request_id', 'unknown'), + 'message': data.get('message', 'no message'), + } + + +@wfr.activity(name='aggregate_results') +def aggregate_results(ctx: WorkflowActivityContext, results: list) -> dict: + """Activity that aggregates results for the same tenant in context.""" + request_ctx = get_request_context() + tenant = request_ctx.get('tenant', 'unknown') if request_ctx else 'unknown' + request_id = request_ctx.get('request_id', 'unknown') if request_ctx else 'unknown' + tenant_results = [ + r['message'] for r in results if r['tenant'] == tenant and r['request_id'] == request_id + ] + + return { + 'tenant': tenant, + 'request_id': request_id, + 'count': len(tenant_results), + 'results': tenant_results, + } + + +@wfr.async_workflow(name='context_propagation_example') +async def context_propagation_workflow(ctx: AsyncWorkflowContext, input_data: dict) -> dict: + """ + Workflow that demonstrates context propagation to activities. + + The workflow receives only the user payload (input_data), not the envelope. + The context is accessible via get_request_context() thanks to the runtime interceptor. + + Activities are executed in parallel using when_all for better performance. 
+ """ + request_ctx = get_request_context() + + # map-reduce pattern + + # Create activity tasks (don't await yet) - metadata will be propagated automatically + # Execute all activities in parallel and get results + results = await ctx.when_all( + [ + ctx.call_activity(process_data, input={'message': 'first task'}), + ctx.call_activity(process_data, input={'message': 'second task'}), + ctx.call_activity(process_data, input={'message': 'third task'}), + ] + ) + + # Aggregate/reduce results + final = await ctx.call_activity(aggregate_results, input=results) + + return {'final': final, 'context_was': request_ctx} + + +def main(): + """ + Demonstrates metadata envelope approach: + 1. Client sets context in contextvar + 2. Client interceptor converts context to metadata + 3. Runtime wraps metadata in envelope: {"__dapr_meta__": {...}, "__dapr_payload__": {...}} + 4. Envelope is persisted durably in workflow state + 5. Runtime unwraps envelope before execution + 6. Runtime interceptor restores context from metadata + 7. User code receives only the payload, not the envelope + """ + print('=' * 70) + print('Metadata Envelope Context Propagation Example (Async)') + print('=' * 70) + + with wfr: + # Create client with client interceptor + client = DaprWorkflowClient(interceptors=[ContextClientInterceptor()]) + + # Set context - this will be converted to metadata by the client interceptor + set_request_context({'tenant': 'acme-corp', 'request_id': 'req-12345'}) + + # Schedule workflow with user payload (metadata is added by interceptor) + instance_id = 'context_example_async' + client.schedule_new_workflow( + workflow=context_propagation_workflow, + input={'task': 'process_orders', 'order_id': 999}, + instance_id=instance_id, + ) + + wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + + print('\n' + '=' * 70) + print('Workflow Result:') + print('=' * 70) + print(f'Status: {wf_state.runtime_status}') + print(f'Output: {wf_state.serialized_output}') + print('=' * 70) + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/fan_out_fan_in.py b/examples/workflow-async/fan_out_fan_in.py new file mode 100644 index 000000000..976211b16 --- /dev/null +++ b/examples/workflow-async/fan_out_fan_in.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, + WorkflowStatus, +) + +wfr = WorkflowRuntime() + + +@wfr.activity(name='square') +def square(ctx: WorkflowActivityContext, x: int) -> int: + return x * x + + +@wfr.async_workflow(name='fan_out_fan_in_async') +async def orchestrator(ctx: AsyncWorkflowContext): + tasks = [ctx.call_activity(square, input=i) for i in range(1, 6)] + # 1 + 4 + 9 + 16 + 25 = 55 + results = await ctx.when_all(tasks) + total = sum(results) + return total + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'fofi_async' + client.schedule_new_workflow(workflow=orchestrator, instance_id=instance_id) + wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print(f'Workflow state: {wf_state}') + wfr.shutdown() + + # simple test + if wf_state.runtime_status != WorkflowStatus.COMPLETED: + print('Workflow failed with status ', wf_state.runtime_status) + exit(1) + if wf_state.serialized_output != '55': + print('Workflow result is incorrect!') + exit(1) + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/fan_out_fan_in_with_gather.py b/examples/workflow-async/fan_out_fan_in_with_gather.py new file mode 100644 index 000000000..1f9268ee0 --- /dev/null +++ b/examples/workflow-async/fan_out_fan_in_with_gather.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +import asyncio + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, + WorkflowStatus, +) + +# test using sandbox to convert asyncio methods into deterministic ones + +wfr = WorkflowRuntime() + + +@wfr.activity(name='square') +def square(ctx: WorkflowActivityContext, x: int) -> int: + return x * x + + +# workflow function auto-recognize coroutine function and converts this into wfr.async_workflow +@wfr.workflow(name='fan_out_fan_in_async') +async def orchestrator(ctx: AsyncWorkflowContext): + tasks = [ctx.call_activity(square, input=i) for i in range(1, 6)] + # 1 + 4 + 9 + 16 + 25 = 55 + results = await asyncio.gather(*tasks) + total = sum(results) + return total + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'fofi_async' + client.schedule_new_workflow(workflow=orchestrator, instance_id=instance_id) + wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print(f'Workflow state: {wf_state}') + wfr.shutdown() + + # simple test + if wf_state.runtime_status != WorkflowStatus.COMPLETED: + print('Workflow failed with status ', wf_state.runtime_status) + exit(1) + if wf_state.serialized_output != '55': + print('Workflow result is incorrect!') + exit(1) + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/human_approval.py b/examples/workflow-async/human_approval.py new file mode 100644 index 000000000..a9f3d1e7b --- /dev/null +++ b/examples/workflow-async/human_approval.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +import time +from datetime import timedelta + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowRuntime, + WorkflowStatus, +) + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name='human_approval_async') +async def orchestrator(ctx: AsyncWorkflowContext, request_id: str): + approve = ctx.wait_for_external_event(f'approve:{request_id}') + reject = ctx.wait_for_external_event(f'reject:{request_id}') + decision = await ctx.when_any( + [ + approve, + reject, + ctx.create_timer(timedelta(seconds=5)), + ] + ) + if decision == approve: + print('Decision Approved') + return request_id + if decision == reject: + print('Decision Rejected') + return 'REJECTED' + return 'TIMEOUT' + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'human_approval_async_1' + try: + # clean up previous workflow with this ID + client.terminate_workflow(instance_id) + client.purge_workflow(instance_id) + except Exception: + pass + client.schedule_new_workflow(workflow=orchestrator, input='req-1', instance_id=instance_id) + time.sleep(1) + client.raise_workflow_event(instance_id, 'approve:req-1') + # In a real scenario, raise approve/reject event from another service. 
+ wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=20) + print(f'Workflow state: {wf_state}') + + wfr.shutdown() + + # simple test + if wf_state.runtime_status != WorkflowStatus.COMPLETED: + print('Workflow failed with status ', wf_state.runtime_status) + exit(1) + if wf_state.serialized_output != '"req-1"': + print('Workflow result is incorrect!') + exit(1) + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/requirements.txt b/examples/workflow-async/requirements.txt new file mode 100644 index 000000000..b6330b015 --- /dev/null +++ b/examples/workflow-async/requirements.txt @@ -0,0 +1,13 @@ +#dapr-ext-workflow-dev>=1.15.0.dev +#dapr-dev>=1.15.0.dev + + +## --- local development: install local packages in editable mode -- + +## -- if using dev version of durabletask-python +-e ../../../durabletask-python + +## -- if using dev version of dapr-ext-workflow +-e ../.. +-e ../../ext/dapr-ext-workflow + diff --git a/examples/workflow-async/simple.py b/examples/workflow-async/simple.py new file mode 100644 index 000000000..f1797f655 --- /dev/null +++ b/examples/workflow-async/simple.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +import json +from datetime import timedelta +from time import sleep + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + RetryPolicy, + WorkflowActivityContext, + WorkflowRuntime, + WorkflowStatus, +) + +counter = 0 +retry_count = 0 +child_orchestrator_string = '' +instance_id = 'asyncExampleInstanceID' +child_instance_id = 'asyncChildInstanceID' +workflow_name = 'async_hello_world_wf' +child_workflow_name = 'async_child_wf' +input_data = 'Hi Async Counter!' 
+event_name = 'event1' +event_data = 'eventData' + +retry_policy = RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + backoff_coefficient=2, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=100), +) + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name=workflow_name) +async def hello_world_wf(ctx: AsyncWorkflowContext, wf_input): + # activities + result_1 = await ctx.call_activity(hello_act, input=1) + print(f'Activity 1 returned {result_1}') + result_2 = await ctx.call_activity(hello_act, input=10) + print(f'Activity 2 returned {result_2}') + result_3 = await ctx.call_activity(hello_retryable_act, retry_policy=retry_policy) + print(f'Activity 3 returned {result_3}') + result_4 = await ctx.call_child_workflow(child_retryable_wf, retry_policy=retry_policy) + print(f'Child workflow returned {result_4}') + + # Event vs timeout using when_any + event_1 = ctx.wait_for_external_event(event_name) + first = await ctx.when_any( + [ + event_1, + ctx.create_timer(timedelta(seconds=30)), + ] + ) + + # Proceed only if event won + if first == event_1: + result_5 = await ctx.call_activity(hello_act, input=100) + result_6 = await ctx.call_activity(hello_act, input=1000) + return dict( + result_1=result_1, + result_2=result_2, + result_3=result_3, + result_4=result_4, + result_5=result_5, + result_6=result_6, + ) + return 'Timeout' + + +@wfr.activity(name='async_hello_act') +def hello_act(ctx: WorkflowActivityContext, wf_input): + global counter + counter += wf_input + return f'Activity returned {wf_input}' + + +@wfr.activity(name='async_hello_retryable_act') +def hello_retryable_act(ctx: WorkflowActivityContext): + global retry_count + if (retry_count % 2) == 0: + retry_count += 1 + raise ValueError('Retryable Error') + retry_count += 1 + return f'Activity returned {retry_count}' + + +@wfr.async_workflow(name=child_workflow_name) +async def child_retryable_wf(ctx: AsyncWorkflowContext): + # Call activity with retry and simulate retryable workflow failure until certain state + child_activity_result = await ctx.call_activity( + act_for_child_wf, input='x', retry_policy=retry_policy + ) + print(f'Child activity returned {child_activity_result}') + # In a real sample, you might check state and raise to trigger retry + return 'ok' + + +@wfr.activity(name='async_act_for_child_wf') +def act_for_child_wf(ctx: WorkflowActivityContext, inp): + global child_orchestrator_string + child_orchestrator_string += inp + + +def main(): + wf_state = {} + with wfr: + wf_client = DaprWorkflowClient() + + wf_client.schedule_new_workflow( + workflow=hello_world_wf, input=input_data, instance_id=instance_id + ) + + wf_client.wait_for_workflow_start(instance_id) + + # Let initial activities run + sleep(5) + + # Raise event to continue + wf_client.raise_workflow_event( + instance_id=instance_id, event_name=event_name, data={'ok': True} + ) + + # Wait for completion + wf_state = wf_client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + + # simple test + if wf_state.runtime_status != WorkflowStatus.COMPLETED: + print('Workflow failed with status ', wf_state.runtime_status) + exit(1) + output = json.loads(wf_state.serialized_output) + if ( + output['result_1'] != 'Activity returned 1' + or output['result_2'] != 'Activity returned 10' + or output['result_3'] != 'Activity returned 2' + or output['result_4'] != 'ok' + or output['result_5'] != 'Activity returned 100' + or output['result_6'] != 'Activity returned 1000' + ): + print('Workflow result is 
incorrect!') + exit(1) + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/task_chaining.py b/examples/workflow-async/task_chaining.py new file mode 100644 index 000000000..25519e30d --- /dev/null +++ b/examples/workflow-async/task_chaining.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, + WorkflowStatus, +) + +wfr = WorkflowRuntime() + + +@wfr.activity(name='sum') +def sum_act(ctx: WorkflowActivityContext, nums): + return sum(nums) + + +@wfr.async_workflow(name='task_chaining_async') +async def orchestrator(ctx: AsyncWorkflowContext): + a = await ctx.call_activity(sum_act, input=[1, 2]) + b = await ctx.call_activity(sum_act, input=[a, 3]) + c = await ctx.call_activity(sum_act, input=[b, 4]) + return c + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'task_chain_async' + client.schedule_new_workflow(workflow=orchestrator, instance_id=instance_id) + wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + wfr.shutdown() + + # simple test + if wf_state.runtime_status != WorkflowStatus.COMPLETED: + print('Workflow failed with status ', wf_state.runtime_status) + exit(1) + # 1 + 2 + 3 + 4 = 10 + if wf_state.serialized_output != '10': + print('Workflow result is incorrect!') + exit(1) + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/README.md b/examples/workflow/README.md index ac70cfa84..7512eff3b 100644 --- a/examples/workflow/README.md +++ b/examples/workflow/README.md @@ -12,6 +12,8 @@ This directory contains examples of using the [Dapr Workflow](https://docs.dapr. You can install dapr SDK package using pip command: ```sh +python3 -m venv .venv +source .venv/bin/activate pip3 install -r requirements.txt ``` @@ -53,10 +55,14 @@ expected_stdout_lines: - "== APP == New counter value is: 111!" - "== APP == New counter value is: 1111!" - "== APP == Workflow completed! 
Result: Completed" -timeout_seconds: 30 +timeout_seconds: 40 --> ```sh +python3 -m venv .venv +source .venv/bin/activate +pip3 install -r requirements.txt + dapr run --app-id wf-simple-example -- python3 simple.py ``` @@ -165,6 +171,10 @@ timeout_seconds: 30 --> ```sh +python3 -m venv .venv +source .venv/bin/activate +pip3 install -r requirements.txt + dapr run --app-id wfexample -- python3 task_chaining.py ``` @@ -212,6 +222,10 @@ timeout_seconds: 30 --> ```sh +python3 -m venv .venv +source .venv/bin/activate +pip3 install -r requirements.txt + dapr run --app-id wfexample -- python3 fan_out_fan_in.py ``` @@ -361,6 +375,10 @@ sleep: 20 --> ```sh +python3 -m venv .venv +source .venv/bin/activate +pip3 install -r requirements.txt + dapr run --app-id wfexample3 python3 cross-app3.py & dapr run --app-id wfexample2 python3 cross-app2.py & dapr run --app-id wfexample1 python3 cross-app1.py @@ -403,6 +421,10 @@ sleep: 20 --> ```sh +python3 -m venv .venv +source .venv/bin/activate +pip3 install -r requirements.txt + export ERROR_ACTIVITY_MODE=true dapr run --app-id wfexample3 python3 cross-app3.py & dapr run --app-id wfexample2 python3 cross-app2.py & @@ -444,6 +466,10 @@ sleep: 20 --> ```sh +python3 -m venv .venv +source .venv/bin/activate +pip3 install -r requirements.txt + export ERROR_WORKFLOW_MODE=true dapr run --app-id wfexample3 python3 cross-app3.py & dapr run --app-id wfexample2 python3 cross-app2.py & diff --git a/examples/workflow/context_interceptors_example.py b/examples/workflow/context_interceptors_example.py new file mode 100644 index 000000000..3b15a523e --- /dev/null +++ b/examples/workflow/context_interceptors_example.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- + +""" +Example: Interceptors for context propagation using metadata envelope (RECOMMENDED). + +This example demonstrates the recommended approach for context propagation: + - Use metadata envelope for durable, transparent context propagation + - Implement ClientInterceptor to set metadata when scheduling workflows + - Implement RuntimeInterceptor to restore context from metadata before execution + - Implement WorkflowOutboundInterceptor to propagate metadata to activities/children + - Use the wrapper pattern with 'yield from' to keep context alive during execution + +CRITICAL: Workflow interceptors MUST use the wrapper pattern and return the result: + def execute_workflow(self, request, nxt): + def wrapper(): + setup_context() + try: + gen = nxt(request) + result = yield from gen # Keep context alive + return result # MUST return to propagate workflow output + finally: + cleanup_context() + return wrapper() + +Without 'return result', the workflow output will be lost (serialized_output will be null). + +Metadata envelope approach: +------------------------------------------ +This example uses the metadata envelope feature for production-ready context propagation. +Metadata is stored separately from the user payload and is transparent to user code. 
+ +Envelope structure (automatically handled by the runtime): + { + "__dapr_meta__": { + "v": 1, + "metadata": {"tenant": "acme", "request_id": "r-123"} + }, + "__dapr_payload__": + } + +Benefits: + - User code never sees envelope structure (receives only the payload) + - Metadata is durably persisted (survives replays, retries, continue-as-new) + - Automatic propagation across workflow → activity → child workflow boundaries + - String-only metadata enforces simple, serializable key-value structure + - Context accessible to interceptors via request.metadata + +Note: Scheduling/running requires a Dapr sidecar. This file focuses on the wiring pattern. +""" + +from __future__ import annotations + +import contextvars +import json +from dataclasses import replace +from typing import Any, Callable + +from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + DaprWorkflowClient, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + ScheduleWorkflowRequest, + WorkflowRuntime, +) + +# Context variable to carry request metadata across workflow/activity execution +_current_ctx: contextvars.ContextVar[dict[str, str] | None] = contextvars.ContextVar( + 'wf_ctx', default=None +) + + +def set_ctx(ctx: dict[str, str] | None) -> None: + """Set the current context (stored in contextvar).""" + _current_ctx.set(ctx) + + +def get_ctx() -> dict[str, str] | None: + """Get the current context from contextvar.""" + return _current_ctx.get() + + +class ContextClientInterceptor(BaseClientInterceptor): + """Client interceptor that sets metadata when scheduling workflows. + + The metadata is automatically wrapped in an envelope by the runtime and + propagated durably across workflow boundaries. + """ + + def schedule_new_workflow( + self, request: ScheduleWorkflowRequest, nxt: Callable[[ScheduleWorkflowRequest], Any] + ) -> Any: # type: ignore[override] + # Get current context and convert to string-only metadata + ctx = get_ctx() + metadata = ctx.copy() if ctx else {} + + print('[Client] Scheduling workflow with metadata:', metadata) + + # Set metadata on the request (runtime will wrap in envelope) + return nxt(replace(request, metadata=metadata)) + + +class ContextWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + """Workflow outbound interceptor that propagates metadata to activities and child workflows. + + The metadata is automatically wrapped in an envelope by the runtime. 
+ """ + + def call_child_workflow( + self, request: CallChildWorkflowRequest, nxt: Callable[[CallChildWorkflowRequest], Any] + ) -> Any: + # Get current context and convert to string-only metadata + ctx = get_ctx() + metadata = ctx.copy() if ctx else {} + + print('[Outbound] Calling child workflow with metadata:', metadata) + + # Use dataclasses.replace() to create a modified copy + return nxt(replace(request, metadata=metadata)) + + def call_activity( + self, request: CallActivityRequest, nxt: Callable[[CallActivityRequest], Any] + ) -> Any: + # Get current context and convert to string-only metadata + ctx = get_ctx() + metadata = ctx.copy() if ctx else {} + + print(f'[Outbound] Calling activity {request.activity_name}') + print(f' -- input: {request.input}') + print(f' -- metadata: {metadata}') + + # Use dataclasses.replace() to create a modified copy + return nxt(replace(request, metadata=metadata)) + + +class ContextRuntimeInterceptor(BaseRuntimeInterceptor): + """Runtime interceptor that restores context from metadata before execution. + + The runtime automatically unwraps the envelope and provides metadata via + request.metadata. User code receives only the original payload via request.input. + """ + + def execute_workflow( + self, request: ExecuteWorkflowRequest, nxt: Callable[[ExecuteWorkflowRequest], Any] + ) -> Any: # type: ignore[override] + """ + IMPORTANT: Use wrapper pattern to keep context alive during generator execution. + + Calling nxt(request) returns a generator immediately; context must stay set + while that generator executes (including during activity calls). + """ + + def wrapper(): + print('[Runtime] Executing workflow') + print(f' -- input (payload only): {request.input}') + print(f' -- metadata: {request.metadata}') + + # Restore context from metadata (automatically unwrapped by runtime) + if request.metadata: + set_ctx(request.metadata) + + try: + gen = nxt(request) + result = yield from gen # Keep context alive while generator executes + return result # Must explicitly return the result from the inner generator + finally: + print('[Runtime] Clearing workflow context') + set_ctx(None) + + return wrapper() + + def execute_activity( + self, request: ExecuteActivityRequest, nxt: Callable[[ExecuteActivityRequest], Any] + ) -> Any: # type: ignore[override] + """ + Restore context from metadata before activity execution. + + The runtime automatically unwraps the envelope and provides metadata via + request.metadata. User code receives only the original payload. + """ + # Restore context from metadata (automatically unwrapped by runtime) + if request.metadata: + set_ctx(request.metadata) + + try: + return nxt(request) + finally: + set_ctx(None) + + +# Example workflow and activity demonstrating context access +def activity_log(ctx, data: dict[str, Any]) -> str: # noqa: ANN001 (example) + """ + Activity that accesses the restored context. + + The context was set in the runtime interceptor from metadata. + The activity receives only the user payload (data), not the envelope. + """ + current_context = get_ctx() + + if current_context is None: + return dict(tenant='unknown', request_id='unknown', msg='no message', data=data) + + return dict( + tenant=current_context.get('tenant', 'unknown'), + request_id=current_context.get('request_id', 'unknown'), + msg=data.get('msg', 'no message'), + data=data, + ) + + +def workflow_example(ctx, wf_input: dict[str, Any]): # noqa: ANN001 (example) + """ + Example workflow that calls an activity. 
+ + The workflow receives only the user payload (wf_input), not the envelope. + The context is accessible via get_ctx() thanks to the runtime interceptor. + """ + current_context = get_ctx() + + # Call activity - metadata will be propagated automatically via outbound interceptor + y = yield ctx.call_activity(activity_log, input={'msg': 'hello from workflow'}) + + return dict(result=y, context_was=json.dumps(current_context)) + + +def wire_up() -> tuple[WorkflowRuntime, DaprWorkflowClient]: + """Set up runtime and client with interceptors.""" + runtime = WorkflowRuntime( + runtime_interceptors=[ContextRuntimeInterceptor()], + workflow_outbound_interceptors=[ContextWorkflowOutboundInterceptor()], + ) + client = DaprWorkflowClient(interceptors=[ContextClientInterceptor()]) + + # Register workflow/activity + runtime.register_workflow(workflow_example, name='example') + runtime.register_activity(activity_log, name='activity_log') + return runtime, client + + +if __name__ == '__main__': + """ + Demonstrates metadata envelope approach: + 1. Client sets context in contextvar + 2. Client interceptor converts context to metadata + 3. Runtime wraps metadata in envelope: {"__dapr_meta__": {...}, "__dapr_payload__": {...}} + 4. Envelope is persisted durably in workflow state + 5. Runtime unwraps envelope before execution + 6. Runtime interceptor restores context from metadata + 7. User code receives only the payload, not the envelope + """ + print('=' * 70) + print('Metadata Envelope Context Propagation Example') + print('=' * 70) + + wrt, client = wire_up() + with wrt: + # Set context - this will be converted to metadata by the client interceptor + set_ctx({'tenant': 'acme-corp', 'request_id': 'req-12345'}) + + # Schedule workflow with user payload (metadata is added by interceptor) + instance_id = client.schedule_new_workflow( + workflow_example, input={'operation': 'process_order', 'order_id': 999} + ) + + wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + + print('\n' + '=' * 70) + print('Workflow Result:') + print('=' * 70) + print(f'Status: {wf_state.runtime_status}') + print(f'Output: {wf_state.serialized_output}') + print('=' * 70) diff --git a/examples/workflow/cross-app-with-retry-interceptor-README.md b/examples/workflow/cross-app-with-retry-interceptor-README.md new file mode 100644 index 000000000..7e0f6c24d --- /dev/null +++ b/examples/workflow/cross-app-with-retry-interceptor-README.md @@ -0,0 +1,133 @@ +# Cross-App Workflow with Default Retry Policy Interceptor + +This example demonstrates how to use workflow outbound interceptors to automatically set default retry policies for activities and child workflows. 
+ +## Features + +- **Default Retry Policies**: Interceptor automatically adds retry policies to activities that don't have one +- **App-Specific Policies**: Different retry policies based on the target `app_id` +- **Policy Preservation**: User-provided retry policies are preserved and not overridden +- **Cross-App Communication**: Demonstrates retry behavior across Dapr app boundaries + +## Architecture + +- **App1** (`cross-app-with-retry-interceptor-app1.py`): + - Runs the main workflow + - Has a custom `DefaultRetryInterceptor` that sets retry policies + - Calls both local and cross-app activities + - Demonstrates explicit retry policy specification + +- **App2** (`cross-app-with-retry-interceptor-app2.py`): + - Provides the `remote_activity` for cross-app calls + - Can simulate failures to test retry behavior + +## Interceptor Behavior + +The `DefaultRetryInterceptor` in app1: + +1. **For cross-app activities** (`app_id='wfexample-retry-app2'`): + - 3 retry attempts + - 500ms initial retry interval + - 5s max retry interval + - Exponential backoff (coefficient 2.0) + +2. **For local activities** (no `app_id` or different app): + - 2 retry attempts + - 100ms initial retry interval + - 2s max retry interval + +3. **User-provided policies**: Preserved unchanged + +## Setup + +### Prerequisites +- Dapr CLI installed +- Python 3.8+ +- Dapr Python SDK installed + +### Install Dependencies +```bash +pip install dapr-ext-workflow +``` + +## Running the Example + +### Terminal 1: Start App2 (Activity Provider) +```bash +dapr run --app-id wfexample-retry-app2 --dapr-grpc-port 50002 \ + -- python cross-app-with-retry-interceptor-app2.py +``` + +### Terminal 2: Start App1 (Main Workflow) +```bash +dapr run --app-id wfexample-retry-app1 --dapr-grpc-port 50001 \ + -- python cross-app-with-retry-interceptor-app1.py +``` + +## Testing Retry Behavior + +To see the retry policy in action, you can simulate failures: + +### Terminal 1: Start App2 with Failure Simulation +```bash +SIMULATE_FAILURE=true dapr run --app-id wfexample-retry-app2 --dapr-grpc-port 50002 --app-port 5002 \ + -- python cross-app-with-retry-interceptor-app2.py +``` + +You should see the interceptor's retry policy kick in and retry the failed activity multiple times. + +## Expected Output + +### App1 Output (Successful Case) +``` +app1 - workflow started +[Interceptor] Setting default retry policy for activity local_activity +app1 - calling local_activity +app1 - local_activity called with input: local-call +app1 - local_activity result: local-result-local-call +[Interceptor] Setting cross-app retry policy for activity remote_activity -> wfexample-retry-app2 +app1 - calling cross-app activity +app1 - remote_activity result: remote-result-cross-app-call +[Interceptor] Preserving user-provided retry policy for local_activity +app1 - calling activity with explicit retry policy +app1 - explicit retry activity result: local-result-explicit-retry +app1 - workflow completed +``` + +### App2 Output +``` +app2 - starting workflow runtime +app2 - workflow runtime started, waiting... +app2 - remote_activity called with input: cross-app-call +app2 - remote_activity completed successfully +``` + +## Key Concepts + +1. **Interceptor Chain**: Interceptors can inspect and modify requests before they're processed +2. **Retry Policies**: Control how failures are handled with automatic retries +3. **Cross-App Communication**: Activities can be called across different Dapr applications +4. 
**Graceful Defaults**: Provide sensible defaults while allowing explicit overrides + +## Customization + +You can customize the retry policies in the `DefaultRetryInterceptor` class: + +```python +class DefaultRetryInterceptor(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, nxt): + if request.retry_policy is None: + # Customize your default retry policy here + retry_policy = wf.RetryPolicy( + max_number_of_attempts=5, # More retries + first_retry_interval=timedelta(seconds=1), + max_retry_interval=timedelta(seconds=10), + ) +``` + +## Clean Up + +Stop both Dapr applications with Ctrl+C in each terminal. + + + diff --git a/examples/workflow/cross-app-with-retry-interceptor-app1.py b/examples/workflow/cross-app-with-retry-interceptor-app1.py new file mode 100644 index 000000000..6d94dd0ea --- /dev/null +++ b/examples/workflow/cross-app-with-retry-interceptor-app1.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: Default Retry Policy Interceptor for Cross-App Activities + +This example demonstrates how to use workflow outbound interceptors to: +- Set default retry policies for all activities that don't have one +- Apply different retry policies based on the target app_id +- Preserve user-provided retry policies while adding defaults +""" + +import time +from dataclasses import replace +from datetime import timedelta +from typing import Any, Callable + +import dapr.ext.workflow as wf +from dapr.ext.workflow import ( + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, +) + + +def print_no_replay(ctx: wf.DaprWorkflowContext): + """Returns a print function that only prints if not replaying""" + if not ctx.is_replaying: + return print + else: + return lambda *args, **kwargs: None + + +class DefaultRetryInterceptor(BaseWorkflowOutboundInterceptor): + """Interceptor that sets default retry policies for activities and child workflows. + + This demonstrates how interceptors can inspect and modify retry_policy and app_id + fields to provide consistent resilience behavior across your workflows. 
+ """ + + def call_activity( + self, request: CallActivityRequest, nxt: Callable[[CallActivityRequest], Any] + ) -> Any: + print = print_no_replay( + request.workflow_ctx + ) # print function that only prints if not replaying + # Set default retry policy if none provided + print(f'[Interceptor] call_activity called with request: {request}') + retry_policy = request.retry_policy + if retry_policy is None: + # Apply different retry policies based on target app + if request.app_id == 'wfexample-retry-app2': + # More aggressive retry for cross-app calls + retry_policy = wf.RetryPolicy( + max_number_of_attempts=4, # 1 + 3 retries + first_retry_interval=timedelta(milliseconds=500), + max_retry_interval=timedelta(seconds=5), + backoff_coefficient=2.0, + ) + print( + f'[Interceptor] Setting cross-app retry policy for activity {request.activity_name} -> {request.app_id}', + ) + else: + # Default retry for local activities + retry_policy = wf.RetryPolicy( + max_number_of_attempts=2, + first_retry_interval=timedelta(milliseconds=100), + max_retry_interval=timedelta(seconds=2), + ) + print( + f'[Interceptor] Setting default retry policy for activity {request.activity_name}', + ) + else: + print( + f'[Interceptor] Preserving user-provided retry policy for {request.activity_name}', + ) + + # Forward with modified request using dataclasses.replace() + return nxt(replace(request, retry_policy=retry_policy)) + + def call_child_workflow( + self, request: CallChildWorkflowRequest, nxt: Callable[[CallChildWorkflowRequest], Any] + ) -> Any: + # Could also set default retry for child workflows + retry_policy = request.retry_policy + if retry_policy is None and request.app_id is not None: + retry_policy = wf.RetryPolicy( + max_number_of_attempts=2, + first_retry_interval=timedelta(milliseconds=200), + max_retry_interval=timedelta(seconds=3), + ) + print( + f'[Interceptor] Setting retry policy for child workflow {request.workflow_name} -> {request.app_id}', + flush=True, + ) + + return nxt(replace(request, retry_policy=retry_policy)) + + +# Create runtime with the interceptor +wfr = wf.WorkflowRuntime(workflow_outbound_interceptors=[DefaultRetryInterceptor()]) + + +@wfr.workflow +def app1_workflow(ctx: wf.DaprWorkflowContext): + print = print_no_replay(ctx) # print function that only prints if not replaying + print('app1 - workflow started') + + # Call local activity (will get default retry policy from interceptor) + print('app1 - calling local_activity') + result1 = yield ctx.call_activity(local_activity, input='local-call') + print(f'app1 - local_activity result: {result1}') + + # Call cross-app activity (will get cross-app retry policy from interceptor) + print('app1 - calling cross-app activity') + result2 = yield ctx.call_activity( + 'remote_activity', input='cross-app-call', app_id='wfexample-retry-app2' + ) + print(f'app1 - remote_activity result: {result2}') + + # Call activity with explicit retry policy (interceptor preserves it) + print('app1 - calling activity with explicit retry policy', flush=True) + explicit_retry = wf.RetryPolicy( + max_number_of_attempts=5, + first_retry_interval=timedelta(milliseconds=50), + max_retry_interval=timedelta(seconds=1), + ) + result3 = yield ctx.call_activity( + local_activity, input='explicit-retry', retry_policy=explicit_retry + ) + print(f'app1 - explicit retry activity result: {result3}', flush=True) + + print('app1 - workflow completed', flush=True) + return {'local': result1, 'remote': result2, 'explicit': result3} + + +@wfr.activity +def local_activity(ctx: 
wf.WorkflowActivityContext, input: str) -> str: + print(f'app1 - local_activity called with input: {input}', flush=True) + return f'local-result-{input}' + + +if __name__ == '__main__': + wfr.start() + time.sleep(10) # wait for workflow runtime to start + + wf_client = wf.DaprWorkflowClient() + print('app1 - scheduling workflow', flush=True) + instance_id = wf_client.schedule_new_workflow(workflow=app1_workflow) + print(f'app1 - workflow scheduled with instance_id: {instance_id}', flush=True) + + # Wait for the workflow to complete + time.sleep(30) + + # Check workflow state + state = wf_client.get_workflow_state(instance_id) + if state: + print(f'app1 - workflow status: {state.runtime_status.name}', flush=True) + if state.serialized_output: + print(f'app1 - workflow output: {state.serialized_output}', flush=True) + + wfr.shutdown() diff --git a/examples/workflow/cross-app-with-retry-interceptor-app2.py b/examples/workflow/cross-app-with-retry-interceptor-app2.py new file mode 100644 index 000000000..9300e2fda --- /dev/null +++ b/examples/workflow/cross-app-with-retry-interceptor-app2.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +App2: Remote Activity Provider + +This app provides the remote_activity that app1 calls cross-app. +It demonstrates how retry policies set by interceptors work across app boundaries. +""" + +import os +import time + +import dapr.ext.workflow as wf + + +def print_no_replay(ctx: wf.DaprWorkflowContext): + """Returns a print function that only prints if not replaying""" + if not ctx.is_replaying: + return print + else: + return lambda *args, **kwargs: None + + +RETRY_ATTEMPT = 0 + +wfr = wf.WorkflowRuntime() + + +@wfr.activity +def remote_activity(ctx: wf.WorkflowActivityContext, input: str) -> str: + print(f'app2 - remote_activity called with input: {input}', flush=True) + + # Optionally simulate failures to see retry behavior + if os.getenv('SIMULATE_FAILURE', 'false') == 'true': + global RETRY_ATTEMPT # NOTE: this is a hack to simulate a failure and see the retry behavior. DO NOT DO THIS IN PRODUCTION as it is not deterministic. 
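+        # Fails the first two attempts, then succeeds on the third. With the cross-app
+        # retry policy applied by app1's interceptor (up to 4 attempts), the workflow
+        # still completes despite the simulated failures.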
+ RETRY_ATTEMPT += 1 + if RETRY_ATTEMPT < 3: + print(f'app2 - simulating temporary failure, attempt {RETRY_ATTEMPT}', flush=True) + raise ValueError('Simulated activity failure') + else: + # failure resolved + print(f'app2 - simulated failure resolved, attempt {RETRY_ATTEMPT}', flush=True) + + print('app2 - remote_activity completed successfully', flush=True) + return f'remote-result-{input}' + + +if __name__ == '__main__': + print('app2 - starting workflow runtime', flush=True) + wfr.start() + print('app2 - workflow runtime started, waiting...', flush=True) + time.sleep(30) # keep running longer to serve cross-app calls + print('app2 - shutting down', flush=True) + wfr.shutdown() diff --git a/examples/workflow/requirements.txt b/examples/workflow/requirements.txt index c5af70b9d..367e80be1 100644 --- a/examples/workflow/requirements.txt +++ b/examples/workflow/requirements.txt @@ -1,2 +1,12 @@ -dapr-ext-workflow>=1.16.0.dev -dapr>=1.16.0.dev +# dapr-ext-workflow-dev>=1.16.0.dev +# dapr-dev>=1.16.0.dev + +# local development: install local packages in editable mode + +# if using dev version of durabletask-python +-e ../../../durabletask-python + +# if using dev version of dapr-ext-workflow +-e ../../ext/dapr-ext-workflow +-e ../.. + diff --git a/ext/dapr-ext-workflow/README.rst b/ext/dapr-ext-workflow/README.rst index aa0003c6e..3273f5517 100644 --- a/ext/dapr-ext-workflow/README.rst +++ b/ext/dapr-ext-workflow/README.rst @@ -16,6 +16,612 @@ Installation pip install dapr-ext-workflow +Async authoring (experimental) +------------------------------ + +This package supports authoring workflows with ``async def`` in addition to the existing generator-based orchestrators. + +**When to use async workflows:** + +Async workflows are a **special case** for integrating external async libraries into the deterministic workflow +execution path. Use async workflows when: + +- You need to call async libraries that provide deterministic operations (e.g., async graph execution frameworks, + async state machines, async DSL interpreters) +- You're building workflow orchestration on top of existing async code that must run deterministically +- You want to use ``async``/``await`` syntax with durable task operations for code clarity + +**Note:** Most workflows should use regular (generator-based) orchestrators and call async I/O through activities. +Async workflows don't run a real event loop - the durable task runtime drives execution deterministically. + +- Register async workflows using ``WorkflowRuntime.workflow`` (auto-detects coroutine) or ``async_workflow`` / ``register_async_workflow``. +- Use ``AsyncWorkflowContext`` for deterministic operations: + + - Activities: ``await ctx.call_activity(activity_fn, input=...)`` + - Child workflows: ``await ctx.call_child_workflow(workflow_fn, input=...)`` + - Timers: ``await ctx.create_timer(seconds|timedelta)`` + - External events: ``await ctx.wait_for_external_event(name)`` + - Concurrency: ``await ctx.when_all([...])``, ``await ctx.when_any([...])`` + - Deterministic utils: ``ctx.now()``, ``ctx.random()``, ``ctx.uuid4()``, ``ctx.new_guid()``, ``ctx.random_string(length)`` + +Interceptors (client/runtime/outbound) +-------------------------------------- + +Interceptors provide a simple, composable way to apply cross-cutting behavior with a single +enter/exit per call. 
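+
+At a glance, every interceptor method has the same shape: receive a typed request object,
+optionally transform it, call ``next(...)``, and optionally shape the result. The sketch below is
+illustrative only (replay-aware logging); the concrete interceptor types and request objects are
+described next.
+
+.. code-block:: python
+
+    from dapr.ext.workflow import (
+        BaseRuntimeInterceptor, ExecuteWorkflowRequest, ExecuteActivityRequest
+    )
+
+    class LoggingRuntimeInterceptor(BaseRuntimeInterceptor):
+        def execute_workflow(self, input: ExecuteWorkflowRequest, next):
+            if not input.ctx.is_replaying:
+                print(f'workflow starting: {input.ctx.instance_id}')
+            return next(input)  # single enter/exit around the call
+
+        def execute_activity(self, input: ExecuteActivityRequest, next):
+            print(f'activity starting with input: {input.input}')
+            return next(input)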
+ +**Inbound vs Outbound:** + +- **Outbound**: Calls going OUT from your code (scheduling workflows, calling activities/children) +- **Inbound**: Calls coming IN to execute your code (runtime invoking workflows/activities) + +**Three interceptor types:** + +- **Client interceptors**: wrap outbound scheduling from the client (``schedule_new_workflow``) +- **Workflow outbound interceptors**: wrap outbound calls made inside workflows (``call_activity``, ``call_child_workflow``) +- **Runtime interceptors**: wrap inbound execution when the runtime invokes workflows and activities (before user code runs) + +Use cases include context propagation, request metadata stamping, replay-aware logging, validation, +and policy enforcement. + +Response/output shaping +~~~~~~~~~~~~~~~~~~~~~~~ + +Interceptors are "around" hooks: they can shape inputs before calling ``next(...)`` and may also +shape the returned value (or map exceptions) after ``next(...)`` returns. This mirrors gRPC +interceptors and keeps the surface simple – one hook per interception point. + +- Client interceptors can transform schedule/query/signal responses. +- Runtime interceptors can transform workflow/activity results (with guardrails below). +- Workflow-outbound interceptors remain input-only to keep awaitable composition simple. + +Examples +^^^^^^^^ + +Client schedule response shaping:: + + from dapr.ext.workflow import ( + DaprWorkflowClient, ClientInterceptor, ScheduleWorkflowRequest + ) + + class ShapeId(ClientInterceptor): + def schedule_new_workflow(self, input: ScheduleWorkflowRequest, next): + raw = next(input) + return f"tenant-A:{raw}" + + client = DaprWorkflowClient(interceptors=[ShapeId()]) + instance_id = client.schedule_new_workflow(my_workflow, input={}) + # instance_id == "tenant-A:" + +Runtime activity result shaping:: + + from dapr.ext.workflow import WorkflowRuntime, RuntimeInterceptor, ExecuteActivityRequest + + class WrapResult(RuntimeInterceptor): + def execute_activity(self, input: ExecuteActivityRequest, next): + res = next(input) + return {"value": res} + + rt = WorkflowRuntime(runtime_interceptors=[WrapResult()]) + @rt.activity + def echo(ctx, x): + return x + # echo(...) returns {"value": x} + +Determinism guardrails (workflows) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Workflow response shaping must be replay-safe: pure transforms only (no I/O, time, RNG). +- Base the transform solely on (input, metadata, original_result). Map errors to typed exceptions. +- Activities are not replayed, so result shaping may perform I/O, but keep it lightweight. + +Generator wrapper pattern (CRITICAL for workflow interceptors) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When writing ``execute_workflow`` interceptors that need to maintain state (context managers, +contextvars, logging contexts), you MUST use the wrapper pattern with ``yield from``: + +.. code-block:: python + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + def wrapper(): + # Set up your context (contextvars, logging, tracing, etc.) 
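+            # Note: setup_my_context()/cleanup_my_context() are placeholders for your own logic;
+            # this setup runs when the durable task runtime first drives the wrapper generator,
+            # not when execute_workflow() returns it.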
+ setup_my_context() + try: + gen = next(input) + result = yield from gen # Keep context alive during execution + return result # CRITICAL: Must return to propagate workflow output + finally: + cleanup_my_context() + return wrapper() + +**Why this matters:** + +- ``next(input)`` returns a generator immediately (does NOT execute the workflow yet) +- The generator executes later when the durable task runtime drives it +- Without the wrapper, any context you set would be cleared before the workflow runs +- **You MUST capture and return the result** from ``yield from gen``, otherwise the workflow + output will be lost (``serialized_output`` will be ``null``) + +**Activities do NOT need this pattern** because ``next(input)`` executes immediately and +returns the result directly (no generator involved). + +Quick start +~~~~~~~~~~~ + +.. code-block:: python + + from __future__ import annotations + import contextvars + from typing import Any, Callable, List + + from dapr.ext.workflow import ( + WorkflowRuntime, + DaprWorkflowClient, + ClientInterceptor, + WorkflowOutboundInterceptor, + RuntimeInterceptor, + ScheduleWorkflowRequest, + CallActivityRequest, + CallChildWorkflowRequest, + ExecuteWorkflowRequest, + ExecuteActivityRequest, + ) + + # Example: propagate a lightweight context dict through inputs + _current_ctx: contextvars.ContextVar[Optional[dict[str, Any]]] = contextvars.ContextVar( + 'wf_ctx', default=None + ) + + def set_ctx(ctx: Optional[dict[str, Any]]): + _current_ctx.set(ctx) + + def _merge_ctx(args: Any) -> Any: + ctx = _current_ctx.get() + if ctx and isinstance(args, dict) and 'context' not in args: + return {**args, 'context': ctx} + return args + + # Typed payloads + class MyWorkflowInput: + def __init__(self, question: str, tags: List[str] | None = None): + self.question = question + self.tags = tags or [] + + class MyActivityInput: + def __init__(self, name: str, count: int): + self.name = name + self.count = count + + class ContextClientInterceptor(ClientInterceptor[MyWorkflowInput]): + def schedule_new_workflow(self, input: ScheduleWorkflowRequest[MyWorkflowInput], nxt: Callable[[ScheduleWorkflowRequest[MyWorkflowInput]], Any]) -> Any: + from dataclasses import replace + return nxt(replace(input, input=_merge_ctx(input.input))) + + class ContextWorkflowOutboundInterceptor(WorkflowOutboundInterceptor[MyWorkflowInput, MyActivityInput]): + def call_child_workflow(self, input: CallChildWorkflowRequest[MyWorkflowInput], nxt: Callable[[CallChildWorkflowRequest[MyWorkflowInput]], Any]) -> Any: + # Use dataclasses.replace() to create a modified copy + from dataclasses import replace + return nxt(replace(input, input=_merge_ctx(input.input))) + + def call_activity(self, input: CallActivityRequest[MyActivityInput], nxt: Callable[[CallActivityRequest[MyActivityInput]], Any]) -> Any: + from dataclasses import replace + return nxt(replace(input, input=_merge_ctx(input.input))) + + class ContextRuntimeInterceptor(RuntimeInterceptor[MyWorkflowInput, MyActivityInput]): + def execute_workflow(self, input: ExecuteWorkflowRequest[MyWorkflowInput], nxt: Callable[[ExecuteWorkflowRequest[MyWorkflowInput]], Any]) -> Any: + # IMPORTANT: Use wrapper pattern for workflows to keep context alive during generator execution. + # nxt(input) returns a generator immediately; context must stay set during execution. 
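+            # The wrapper below delegates with `yield from`, so send()/throw() from the runtime
+            # reach the inner generator and resumed results are not dropped.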
+ def wrapper(): + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + gen = nxt(input) + result = yield from gen # Keep context alive while generator executes + return result # MUST return result to propagate workflow output + finally: + set_ctx(None) + return wrapper() + + def execute_activity(self, input: ExecuteActivityRequest[MyActivityInput], nxt: Callable[[ExecuteActivityRequest[MyActivityInput]], Any]) -> Any: + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + # Wire into client and runtime + runtime = WorkflowRuntime( + runtime_interceptors=[ContextRuntimeInterceptor()], + workflow_outbound_interceptors=[ContextWorkflowOutboundInterceptor()], + ) + + client = DaprWorkflowClient(interceptors=[ContextClientInterceptor()]) + +**Tip: Using dataclasses.replace() for cleaner interceptors** + +Since interceptor request objects are dataclasses, you can use ``dataclasses.replace()`` to create +modified copies without manually copying all fields: + +.. code-block:: python + + from dataclasses import replace + + class MyInterceptor(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, nxt): + # Instead of manually listing all fields: + # return nxt(CallActivityRequest( + # activity_name=request.activity_name, + # input=request.input, + # retry_policy=modified_policy, + # app_id=request.app_id, + # workflow_ctx=request.workflow_ctx, + # metadata=request.metadata, + # )) + + # Use replace() to only specify changed fields: + return nxt(replace(request, retry_policy=modified_policy)) + +This is more concise, future-proof and less error-prone when adding or modifying fields. + +Context metadata (durable propagation) +------------------------------------- + +Interceptors support a durable context channel: + +- ``metadata``: a string-only dict that is durably persisted and propagated across workflow + boundaries (schedule, child workflows, activities). Typical use: tracing and correlation ids + (e.g., ``otel.trace_id``), tenancy, request ids. This is provider-agnostic and does not require + changes to your workflow/activities. + +How it works +~~~~~~~~~~~~ + +- Client interceptors can set ``metadata`` when scheduling a workflow or calling activities/children. +- Runtime unwraps a reserved envelope before user code runs and exposes the metadata to + ``RuntimeInterceptor`` via ``ExecuteWorkflowRequest.metadata`` / ``ExecuteActivityRequest.metadata``, + while delivering only the original payload to the user function. +- Outbound calls made inside a workflow use client interceptors; when ``metadata`` is present on the + call input, the runtime re-wraps the payload to persist and propagate it. + +Envelope structure (backward compatible) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Internally, the runtime persists metadata by wrapping inputs in an envelope before serialization +to the durable store. This envelope is transparent to user code and is automatically unwrapped +before workflow/activity execution. 
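+
+Conceptually, the wrapping is a small, reversible transformation of the serialized input. The
+sketch below is illustrative only (the function names are made up and this is not the SDK's
+internal code); the exact envelope format is shown immediately below:
+
+.. code-block:: python
+
+    def wrap(payload, metadata):
+        # No metadata: pass the payload through unchanged (no envelope is created)
+        if not metadata:
+            return payload
+        return {'__dapr_meta__': {'v': 1, 'metadata': metadata}, '__dapr_payload__': payload}
+
+    def unwrap(value):
+        # Returns (payload, metadata); non-enveloped payloads come back untouched
+        if isinstance(value, dict) and '__dapr_meta__' in value:
+            return value.get('__dapr_payload__'), value['__dapr_meta__'].get('metadata')
+        return value, None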
+ +**Envelope format:** + +:: + + { + "__dapr_meta__": { + "v": 1, # Version for future compatibility + "metadata": { # String-only key-value pairs + "tenant": "acme-corp", + "request_id": "req-12345", + "otel.trace_id": "abc123...", + "custom_header": "value" + } + }, + "__dapr_payload__": # User's actual input (any JSON-serializable type) + } + +**JSON serialization example** (as stored in Dapr state/history): + +.. code-block:: json + + { + "__dapr_meta__": { + "v": 1, + "metadata": { + "tenant": "acme-corp", + "request_id": "req-12345" + } + }, + "__dapr_payload__": { + "user_id": 123, + "action": "process_order" + } + } + +**Key properties:** + +- **Automatic unwrapping**: The runtime unwraps this automatically before calling user code, so + workflows and activities receive only the ``__dapr_payload__`` content in their original structure +- **Version field**: The ``v`` field is reserved for forward compatibility (currently always 1) +- **Metadata constraints**: Only string keys and string values are supported in metadata +- **Backward compatible**: If no metadata is present, the payload is passed through unchanged + (no envelope is created) +- **Transparent to user code**: User functions never see the ``__dapr_meta__`` or ``__dapr_payload__`` + keys; they work with the original input types + +**Propagation flow:** + +1. **Client → Workflow**: Client interceptor sets metadata, runtime wraps before scheduling +2. **Workflow → Activity**: Outbound interceptor sets metadata, runtime wraps before activity call +3. **Workflow → Child Workflow**: Outbound interceptor sets metadata, runtime wraps before child call +4. **Runtime unwrap**: Before executing user code, runtime unwraps and exposes metadata via + ``ExecuteWorkflowRequest.metadata`` or ``ExecuteActivityRequest.metadata`` to interceptors + +**Size and content guidelines:** + +- Keep metadata small (typically < 1KB); avoid large values +- Use for cross-cutting concerns: trace IDs, tenant info, request IDs, correlation IDs +- Do NOT store sensitive data without encryption/redaction policies +- Consider enforcing size limits in custom interceptors if needed + +Minimal input guidance (SDK-facing) +----------------------------------- + +- Workflow input SHOULD be JSON serializable and a preferably a single dict carried under ``ExecuteWorkflowRequest.input``. Prefer a + single object over positional ``input`` to avoid shape ambiguity and ease future evolution. This is + a recommendation for consistency and versioning; the SDK accepts any JSON-serializable input type + (dict, list, or scalar) and preserves the original shape when unwrapping the envelope. + +- For contextual data, you can use "headers" (aliases for metadata) on the workflow context: + ``set_headers``/``get_headers`` behave the same as ``set_metadata``/``get_metadata`` and are + provided for familiarity with systems that use header terminology. ``continue_as_new`` also + supports ``carryover_headers`` as an alias to ``carryover_metadata``. +- If your app needs a tracing or correlation fallback, include a small ``trace_context`` dict in + your input envelope. Interceptors should restore from ``metadata`` first (see below), then + optionally fall back to this field when present. + +Example (generic): + +.. code-block:: json + + { + "schema_version": "your-app:workflow_input@v1", + "trace_context": { "trace_id": "...", "span_id": "..." 
}, + "payload": { } + } + +Determinism and safety +~~~~~~~~~~~~~~~~~~~~~~ + +- In workflows, read metadata and avoid non-deterministic operations inside interceptors. Do not + perform network I/O in orchestrators. +- Activities may read/modify metadata and perform I/O inside the activity function if desired. + +Metadata persistence lifecycle +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``ctx.set_metadata()`` attaches a string-only dict to the current workflow activation. The runtime + persists it by wrapping inputs in the envelope shown above. Set metadata before yielding or + returning from an activation to ensure it is durably recorded. +- ``continue_as_new``: metadata is not implicitly carried. Use + ``ctx.continue_as_new(new_input, carryover_metadata=True)`` to carry current metadata or provide a + dict to merge/override: ``carryover_metadata={"key": "value"}``. +- Child workflows and activities: metadata is propagated when set on the outbound call input by + interceptors. If you maintain a baseline via ``ctx.set_metadata(...)``, your + ``WorkflowOutboundInterceptor`` can merge it into call-specific metadata. + +Tracing interceptors (example) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can implement tracing as interceptors that stamp/propagate IDs in ``metadata`` and suppress +spans during replay. A minimal sketch: + +.. code-block:: python + + from typing import Any, Callable + from dapr.ext.workflow import ( + BaseClientInterceptor, BaseWorkflowOutboundInterceptor, BaseRuntimeInterceptor, + WorkflowRuntime, DaprWorkflowClient, + ScheduleWorkflowRequest, CallActivityRequest, CallChildWorkflowRequest, + ExecuteWorkflowRequest, ExecuteActivityRequest, + ) + + TRACE_ID_KEY = 'otel.trace_id' + + class TracingClientInterceptor(BaseClientInterceptor): + def __init__(self, get_trace: Callable[[], str]): + self._get = get_trace + def schedule_new_workflow(self, input: ScheduleWorkflowRequest, next): + md = dict(input.metadata or {}) + md.setdefault(TRACE_ID_KEY, self._get()) + return next(replace(input, metadata=md)) + + class TracingWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + def __init__(self, get_trace: Callable[[], str]): + self._get = get_trace + def call_activity(self, input: CallActivityRequest, next): + md = dict(input.metadata or {}) + md.setdefault(TRACE_ID_KEY, self._get()) + return next(type(input)( + activity_name=input.activity_name, + input=input.input, + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=md, + )) + def call_child_workflow(self, input: CallChildWorkflowRequest, next): + md = dict(input.metadata or {}) + md.setdefault(TRACE_ID_KEY, self._get()) + return next(type(input)( + workflow_name=input.workflow_name, + input=input.input, + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=md, + )) + + class TracingRuntimeInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + if not input.ctx.is_replaying: + _trace_id = (input.metadata or {}).get(TRACE_ID_KEY) + # start workflow span here + return next(input) + def execute_activity(self, input: ExecuteActivityRequest, next): + _trace_id = (input.metadata or {}).get(TRACE_ID_KEY) + # start activity span here + return next(input) + + rt = WorkflowRuntime( + runtime_interceptors=[TracingRuntimeInterceptor()], + workflow_outbound_interceptors=[TracingWorkflowOutboundInterceptor(lambda: 'trace-123')], + ) + client = DaprWorkflowClient(interceptors=[TracingClientInterceptor(lambda: 'trace-123')]) + +See the full 
runnable example in ``ext/dapr-ext-workflow/examples/tracing_interceptors_example.py``. + +Recommended tracing restoration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Restore tracing from ``ExecuteWorkflowRequest.metadata`` first (e.g., a key like ``otel.trace_id``) + to preserve determinism and cross-activation continuity without touching user payloads. +- If no tracing metadata is present, optionally fall back to ``input.trace_context`` in your + application-defined input envelope. +- Suppress workflow spans during replay by checking ``input.ctx.is_replaying`` in runtime + interceptors. + +Execution info (minimal) and context properties +----------------------------------------------- + +``execution_info`` is minimal and only includes the durable ``inbound_metadata`` that was +propagated into this activation. Use context properties directly for all engine fields: + +- Manage outbound propagation via ``ctx.set_metadata(...)`` / ``ctx.get_metadata()``. The runtime + persists and propagates these values through the metadata envelope. + +Example: + +.. code-block:: python + + # In a workflow function + inbound = ctx.execution_info.inbound_metadata if ctx.execution_info else None + # Prepare outbound propagation + baseline = ctx.get_metadata() or {} + ctx.set_metadata({**baseline, 'tenant': 'acme'}) + +Notes +~~~~~ + +- User functions never see the envelope keys; they get the same input as before. +- Only string keys/values should be stored in headers/metadata; enforce size limits and redaction + policies as needed. +- The workflow context provides deterministic fields such as ``ctx.workflow_name``, + ``ctx.instance_id``, ``ctx.is_replaying``, and ``ctx.current_utc_datetime``. + Note that ``ctx.execution_info`` only contains ``inbound_metadata``; use context properties + directly for engine-provided fields. +- Interceptors are synchronous and must not perform I/O in orchestrators. Activities may perform + I/O inside the user function; interceptor code should remain fast and replay-safe. +- Client interceptors are applied when calling ``DaprWorkflowClient.schedule_new_workflow(...)`` and + when orchestrators call ``ctx.call_activity(...)`` or ``ctx.call_child_workflow(...)``. + + +Best-effort sandbox +~~~~~~~~~~~~~~~~~~~ + +Opt-in scoped compatibility mode maps ``asyncio.sleep``, ``random``, ``uuid.uuid4``, and ``time.time`` to deterministic equivalents during workflow execution. Use ``sandbox_mode="best_effort"`` or ``"strict"`` when registering async workflows. Strict mode blocks ``asyncio.create_task`` in orchestrators. + +Examples +~~~~~~~~ + +See ``examples/workflow-async/`` for complete examples: + +- ``simple.py`` - Comprehensive example with activities, child workflows, retry policies, and external events +- ``task_chaining.py`` - Sequential activity calls +- ``child_workflow.py`` - Parent/child workflow patterns +- ``fan_out_fan_in.py`` - Parallel activity execution +- ``human_approval.py`` - External event handling +- ``async_http_activity.py`` - Async activities with HTTP requests +- ``context_interceptors_example.py`` - Context propagation using interceptors + +Async Activities +~~~~~~~~~~~~~~~~ + +Activities can be either synchronous or asynchronous functions. Async activities are useful for I/O-bound operations like HTTP requests, database queries, or file operations: + +.. 
code-block:: python + + from durabletask.task import ActivityContext + + # Synchronous activity + def sync_activity(ctx: ActivityContext, data: str) -> str: + return data.upper() + + # Asynchronous activity + async def async_activity(ctx: ActivityContext, data: str) -> str: + # Perform async I/O operations + async with aiohttp.ClientSession() as session: + async with session.get(f"https://api.example.com/{data}") as response: + result = await response.json() + return result + +Both sync and async activities are registered the same way: + +.. code-block:: python + + worker.add_activity(sync_activity) + worker.add_activity(async_activity) + +Orchestrators call them identically regardless of whether they're sync or async - the SDK handles the execution automatically. + +Determinism and semantics +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``when_any`` losers: the first-completer result is returned; non-winning awaitables are ignored deterministically (no additional commands are emitted by the orchestrator for cancellation). This ensures replay stability. Integration behavior with the sidecar is subject to the Durable Task scheduler; the orchestrator does not actively cancel losers. +- Suspension and termination: when an instance is suspended, only new external events are buffered while replay continues to reconstruct state; async orchestrators can inspect ``ctx.is_suspended`` if exposed by the runtime. Termination completes the orchestrator with TERMINATED status and does not raise into the coroutine. End-to-end confirmation requires running against a sidecar; unit tests in this repo do not start a sidecar. + +Async patterns +~~~~~~~~~~~~~~ + +- Activities + + - Call: ``await ctx.call_activity(activity_fn, input=..., retry_policy=...)`` + - Activity functions can be ``def`` or ``async def``. When ``async def`` is used, the runtime awaits them. + +- Timers + + - Create a durable timer: ``await ctx.create_timer(seconds|timedelta)`` + +- External events + + - Wait: ``await ctx.wait_for_external_event(name)`` + - Raise (from client): ``DaprWorkflowClient.raise_workflow_event(instance_id, name, data)`` + +- Concurrency + + - All: ``results = await ctx.when_all([ ...awaitables... ])`` + - Any: ``first = await ctx.when_any([ ...awaitables... ])`` (non-winning awaitables are ignored deterministically) + +- Child workflows + + - Call: ``await ctx.call_child_workflow(workflow_fn, input=..., retry_policy=...)`` + +- Deterministic utilities + + - ``ctx.now()`` returns orchestration time from history + - ``ctx.random()`` returns a deterministic PRNG + - ``ctx.uuid4()`` returns a PRNG-derived deterministic UUID + +Runtime compatibility +--------------------- + +- ``ctx.is_suspended`` is surfaced if provided by the underlying runtime/context version; behavior may vary by Durable Task build. Integration tests that validate suspension semantics are gated behind a sidecar harness. + +when_any losers diagnostics (integration) +----------------------------------------- + +- When the sidecar exposes command diagnostics, you can assert only a single command set is emitted for a ``when_any`` (the orchestrator completes after the first winner without emitting cancels). Until then, unit tests assert single-yield behavior and README documents the expected semantics. + +Notes +----- + +- Orchestrators authored as ``async def`` are not driven by a global event loop you start. The Durable Task worker drives them via a coroutine-to-generator bridge; do not call ``asyncio.run`` around orchestrators. 
+- Use ``WorkflowRuntime.workflow`` with an ``async def`` (auto-detected) or ``WorkflowRuntime.async_workflow`` to register async orchestrators. + +Why async without an event loop? +-------------------------------- + +- Each ``await`` in an async orchestrator corresponds to a deterministic Durable Task decision (activity, timer, external event, ``when_all/any``). The worker advances the coroutine by sending results/exceptions back in, preserving replay and ordering. +- This gives you the readability and structure of ``async/await`` while enforcing workflow determinism (no ad-hoc I/O in orchestrators; all I/O happens in activities). +- The pattern follows other workflow engines (e.g., Durable Functions/Temporal): async authoring for clarity, runtime-driven scheduling for correctness. + References ---------- diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py index dd2d45b75..75d1c6211 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py @@ -14,8 +14,25 @@ """ # Import your main classes here +from dapr.ext.workflow.aio import AsyncWorkflowContext from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext, when_all, when_any +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ClientInterceptor, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + ScheduleWorkflowRequest, + WorkflowOutboundInterceptor, + compose_runtime_chain, + compose_workflow_outbound_chain, +) from dapr.ext.workflow.retry_policy import RetryPolicy from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name @@ -25,6 +42,7 @@ 'WorkflowRuntime', 'DaprWorkflowClient', 'DaprWorkflowContext', + 'AsyncWorkflowContext', 'WorkflowActivityContext', 'WorkflowState', 'WorkflowStatus', @@ -32,4 +50,20 @@ 'when_any', 'alternate_name', 'RetryPolicy', + # interceptors + 'ClientInterceptor', + 'BaseClientInterceptor', + 'WorkflowOutboundInterceptor', + 'BaseWorkflowOutboundInterceptor', + 'RuntimeInterceptor', + 'BaseRuntimeInterceptor', + 'ScheduleWorkflowRequest', + 'CallChildWorkflowRequest', + 'CallActivityRequest', + 'ExecuteWorkflowRequest', + 'ExecuteActivityRequest', + 'compose_workflow_outbound_chain', + 'compose_runtime_chain', + 'WorkflowExecutionInfo', + 'ActivityExecutionInfo', ] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py index ceb8672be..02b1ea412 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py @@ -14,7 +14,18 @@ """ from .dapr_workflow_client import DaprWorkflowClient +from .async_context import AsyncWorkflowContext +from .async_driver import CoroutineOrchestratorRunner +from .awaitables import ActivityAwaitable, ExternalEventAwaitable, SleepAwaitable, SubOrchestratorAwaitable, WhenAnyAwaitable, WhenAllAwaitable __all__ = [ + 'ActivityAwaitable', + 'AsyncWorkflowContext', + 'CoroutineOrchestratorRunner', + 'ExternalEventAwaitable', 'DaprWorkflowClient', + 'SleepAwaitable', + 'SubOrchestratorAwaitable', + 
'WhenAnyAwaitable', + 'WhenAllAwaitable', ] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py new file mode 100644 index 000000000..56fecb100 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py @@ -0,0 +1,160 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any, Awaitable, Callable, Sequence + +from durabletask import task +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, +) + +from .awaitables import ( + ActivityAwaitable, + ExternalEventAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) + +""" +Async workflow context that exposes deterministic awaitables for activities, timers, +external events, and concurrency, along with deterministic utilities. +""" + + +class AsyncWorkflowContext(DeterministicContextMixin): + def __init__(self, base_ctx: task.OrchestrationContext): + self._base_ctx = base_ctx + # Initialize metadata from base context if available + base_metadata = None + if hasattr(base_ctx, 'get_metadata') and callable(base_ctx.get_metadata): + base_metadata = base_ctx.get_metadata() + self._metadata: dict[str, str] | None = base_metadata + + # Core workflow metadata parity with sync context + @property + def instance_id(self) -> str: + return self._base_ctx.instance_id + + @property + def current_utc_datetime(self) -> datetime: + return self._base_ctx.current_utc_datetime + + # Activities & Sub-orchestrations + def call_activity( + self, + activity_fn: Callable[..., Any], + *, + input: Any = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ) -> Awaitable[Any]: + return ActivityAwaitable( + self._base_ctx, activity_fn, input=input, retry_policy=retry_policy, metadata=metadata + ) + + def call_child_workflow( + self, + workflow_fn: Callable[..., Any], + *, + input: Any = None, + instance_id: str | None = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ) -> Awaitable[Any]: + return SubOrchestratorAwaitable( + self._base_ctx, + workflow_fn, + input=input, + instance_id=instance_id, + retry_policy=retry_policy, + metadata=metadata, + ) + + @property + def is_replaying(self) -> bool: + return self._base_ctx.is_replaying + + # Timers & Events + def create_timer(self, fire_at: int | float | timedelta | datetime) -> Awaitable[None]: + # If float provided, interpret as seconds + if isinstance(fire_at, (int, float)): + fire_at = timedelta(seconds=float(fire_at)) + return SleepAwaitable(self._base_ctx, fire_at) + + def wait_for_external_event(self, name: str) -> Awaitable[Any]: + return ExternalEventAwaitable(self._base_ctx, name) + + # Concurrency + def when_all(self, awaitables: Sequence[Awaitable[Any]]) -> Awaitable[list[Any]]: + return WhenAllAwaitable(awaitables) + + def when_any(self, awaitables: 
Sequence[Awaitable[Any]]) -> Awaitable[Any]: + return WhenAnyAwaitable(awaitables) + + # Deterministic utilities are provided by mixin (now, random, uuid4, new_guid) + + @property + def is_suspended(self) -> bool: + # Placeholder; will be wired when Durable Task exposes this state in context + return self._base_ctx.is_suspended + + # Pass-throughs for completeness + def set_custom_status(self, custom_status: str) -> None: + if hasattr(self._base_ctx, 'set_custom_status'): + self._base_ctx.set_custom_status(custom_status) + + def continue_as_new( + self, + new_input: Any, + *, + save_events: bool = False, + carryover_metadata: bool | dict[str, str] = False, + carryover_headers: bool | dict[str, str] | None = None, + ) -> None: + effective_carryover = ( + carryover_headers if carryover_headers is not None else carryover_metadata + ) + # Try extended signature; fall back to minimal for older fakes/contexts + try: + self._base_ctx.continue_as_new( + new_input, save_events=save_events, carryover_metadata=effective_carryover + ) + except TypeError: + self._base_ctx.continue_as_new(new_input, save_events=save_events) + + # Metadata parity + def set_metadata(self, metadata: dict[str, str] | None) -> None: + self._metadata = dict(metadata) if metadata else None + # Sync with base context if it supports metadata + if hasattr(self._base_ctx, 'set_metadata') and callable(self._base_ctx.set_metadata): + self._base_ctx.set_metadata(self._metadata) + + def get_metadata(self) -> dict[str, str] | None: + return dict(self._metadata) if self._metadata else None + + # Header aliases (ergonomic alias for users familiar with Temporal terminology) + def set_headers(self, headers: dict[str, str] | None) -> None: + self.set_metadata(headers) + + def get_headers(self) -> dict[str, str] | None: + return self.get_metadata() + + # Execution info parity - use our own managed _execution_info attribute + @property + def execution_info(self): # type: ignore[override] + return getattr(self._base_ctx, '_execution_info', None) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py new file mode 100644 index 000000000..06f2652ac --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py @@ -0,0 +1,21 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from __future__ import annotations + +from durabletask.aio.driver import ( # type: ignore[import-not-found] + CoroutineOrchestratorRunner as _DTCoroutineOrchestratorRunner, +) + +# Re-export durabletask's CoroutineOrchestratorRunner +CoroutineOrchestratorRunner = _DTCoroutineOrchestratorRunner diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py new file mode 100644 index 000000000..e1ac079d4 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py @@ -0,0 +1,110 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from typing import Any, Callable + +from durabletask import task +from durabletask.aio.awaitables import ( + AwaitableBase as _BaseAwaitable, # type: ignore[import-not-found] +) +from durabletask.aio.awaitables import ( + ExternalEventAwaitable as _DTExternalEventAwaitable, +) +from durabletask.aio.awaitables import ( + SleepAwaitable as _DTSleepAwaitable, +) +from durabletask.aio.awaitables import ( + WhenAllAwaitable as _DTWhenAllAwaitable, +) +from durabletask.aio.awaitables import ( + WhenAnyAwaitable as _DTWhenAnyAwaitable, +) + +AwaitableBase = _BaseAwaitable + + +class ActivityAwaitable(AwaitableBase): + def __init__( + self, + ctx: Any, + activity_fn: Callable[..., Any], + *, + input: Any = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + app_id: str | None = None, + ): + self._ctx = ctx + self._activity_fn = activity_fn + self._input = input + self._retry_policy = retry_policy + self._metadata = metadata + self._app_id = app_id + + def _to_task(self) -> task.Task: + return self._ctx.call_activity( + self._activity_fn, + input=self._input, + retry_policy=self._retry_policy, + metadata=self._metadata, + app_id=self._app_id, + ) + + +class SubOrchestratorAwaitable(AwaitableBase): + def __init__( + self, + ctx: Any, + workflow_fn: Callable[..., Any], + *, + input: Any = None, + instance_id: str | None = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + app_id: str | None = None, + ): + self._ctx = ctx + self._workflow_fn = workflow_fn + self._input = input + self._instance_id = instance_id + self._retry_policy = retry_policy + self._metadata = metadata + self._app_id = app_id + + def _to_task(self) -> task.Task: + return self._ctx.call_child_workflow( + self._workflow_fn, + input=self._input, + instance_id=self._instance_id, + retry_policy=self._retry_policy, + metadata=self._metadata, + app_id=self._app_id, + ) + + +class SleepAwaitable(_DTSleepAwaitable): + pass + + +class ExternalEventAwaitable(_DTExternalEventAwaitable): + pass + + +class WhenAllAwaitable(_DTWhenAllAwaitable): + pass + + +class WhenAnyAwaitable(_DTWhenAnyAwaitable): + pass diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py index 461bfd43a..0faa64a85 100644 --- 
a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py @@ -19,6 +19,12 @@ from typing import Any, Optional, TypeVar import durabletask.internal.orchestrator_service_pb2 as pb +from dapr.ext.workflow.interceptors import ( + ClientInterceptor, + ScheduleWorkflowRequest, + compose_client_chain, + wrap_payload_with_metadata, +) from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.util import getAddress from dapr.ext.workflow.workflow_context import Workflow @@ -51,6 +57,8 @@ def __init__( host: Optional[str] = None, port: Optional[str] = None, logger_options: Optional[LoggerOptions] = None, + *, + interceptors: list[ClientInterceptor] | None = None, ): address = getAddress(host, port) @@ -61,18 +69,31 @@ def __init__( self._logger = Logger('DaprWorkflowClient', logger_options) - metadata = tuple() + metadata = () if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) options = self._logger.get_options() + # Optional gRPC channel options (keepalive, retry policy) via helpers + # channel_options = build_grpc_channel_options() + + # Construct base kwargs for TaskHubGrpcClient + base_kwargs = { + 'host_address': uri.endpoint, + 'metadata': metadata, + 'secure_channel': uri.tls, + 'log_handler': options.log_handler, + 'log_formatter': options.log_formatter, + } + + # Initialize TaskHubGrpcClient (DurableTask supports options) self.__obj = client.TaskHubGrpcClient( - host_address=uri.endpoint, - metadata=metadata, - secure_channel=uri.tls, - log_handler=options.log_handler, - log_formatter=options.log_formatter, + **base_kwargs, + # channel_options=channel_options, ) + # Interceptors + self._client_interceptors: list[ClientInterceptor] = list(interceptors or []) + def schedule_new_workflow( self, workflow: Workflow, @@ -81,6 +102,7 @@ def schedule_new_workflow( instance_id: Optional[str] = None, start_at: Optional[datetime] = None, reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None, + metadata: dict[str, str] | None = None, ) -> str: """Schedules a new workflow instance for execution. @@ -95,25 +117,39 @@ def schedule_new_workflow( be scheduled immediately. reuse_id_policy: Optional policy to reuse the workflow id when there is a conflict with an existing workflow instance. + metadata (dict[str, str] | None): Optional dictionary of key-value pairs + to be included as metadata/headers for the workflow. Returns: The ID of the scheduled workflow instance. 
""" - if hasattr(workflow, '_dapr_alternate_name'): + wf_name = ( + workflow.__dict__['_dapr_alternate_name'] + if hasattr(workflow, '_dapr_alternate_name') + else workflow.__name__ + ) + + # Build interceptor chain around schedule call + def terminal(term_req: ScheduleWorkflowRequest) -> str: + payload = wrap_payload_with_metadata(term_req.input, term_req.metadata) return self.__obj.schedule_new_orchestration( - workflow.__dict__['_dapr_alternate_name'], - input=input, - instance_id=instance_id, - start_at=start_at, - reuse_id_policy=reuse_id_policy, + term_req.workflow_name, + input=payload, + instance_id=term_req.instance_id, + start_at=term_req.start_at, + reuse_id_policy=term_req.reuse_id_policy, ) - return self.__obj.schedule_new_orchestration( - workflow.__name__, + + chain = compose_client_chain(self._client_interceptors, terminal) + schedule_req = ScheduleWorkflowRequest( + workflow_name=wf_name, input=input, instance_id=instance_id, start_at=start_at, reuse_id_policy=reuse_id_policy, + metadata=metadata, ) + return chain(schedule_req) def get_workflow_state( self, instance_id: str, *, fetch_payloads: bool = True diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py index 714def3f2..b37ba6d92 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,28 +11,61 @@ limitations under the License. """ +import enum from datetime import datetime, timedelta from typing import Any, Callable, List, Optional, TypeVar, Union +from dapr.ext.workflow.execution_info import WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import unwrap_payload_with_metadata, wrap_payload_with_metadata from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.retry_policy import RetryPolicy from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext from dapr.ext.workflow.workflow_context import Workflow, WorkflowContext from durabletask import task +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, +) T = TypeVar('T') TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') -class DaprWorkflowContext(WorkflowContext): - """DaprWorkflowContext that provides proxy access to internal OrchestrationContext instance.""" +class Handlers(enum.Enum): + CALL_ACTIVITY = 'call_activity' + CALL_CHILD_WORKFLOW = 'call_child_workflow' + CONTINUE_AS_NEW = 'continue_as_new' + + +class DaprWorkflowContext(WorkflowContext, DeterministicContextMixin): + """Workflow context wrapper with deterministic utilities and metadata helpers. + + Purpose + ------- + - Proxy to the underlying ``durabletask.task.OrchestrationContext`` (engine fields like + ``trace_parent``, ``orchestration_span_id``, and ``workflow_attempt`` pass through). + - Provide SDK-level helpers for durable metadata propagation via interceptors. + - Expose ``execution_info`` as a per-activation snapshot complementing live properties. + + Tips + ---- + - Use ``ctx.get_metadata()/set_metadata()`` to manage outbound propagation. + - Use ``ctx.execution_info.inbound_metadata`` to inspect what arrived on this activation. 
+ - Prefer engine-backed properties for tracing/attempts when available (not yet available in dapr sidecar); fall back to + metadata only for app-specific context. + """ def __init__( - self, ctx: task.OrchestrationContext, logger_options: Optional[LoggerOptions] = None + self, + ctx: task.OrchestrationContext, + logger_options: Optional[LoggerOptions] = None, + *, + outbound_handlers: dict[Handlers, Any] | None = None, ): self.__obj = ctx self._logger = Logger('DaprWorkflowContext', logger_options) + self._outbound_handlers = outbound_handlers or {} + self._metadata: dict[str, str] | None = None # provide proxy access to regular attributes of wrapped object def __getattr__(self, name): @@ -52,10 +83,34 @@ def current_utc_datetime(self) -> datetime: def is_replaying(self) -> bool: return self.__obj.is_replaying + # Deterministic utilities are provided by mixin (now, random, uuid4, new_guid) + + # Metadata API + def set_metadata(self, metadata: dict[str, str] | None) -> None: + self._metadata = dict(metadata) if metadata else None + + def get_metadata(self) -> dict[str, str] | None: + return dict(self._metadata) if self._metadata else None + + # Header aliases (ergonomic alias for users familiar with Temporal terminology) + def set_headers(self, headers: dict[str, str] | None) -> None: + self.set_metadata(headers) + + def get_headers(self) -> dict[str, str] | None: + return self.get_metadata() + def set_custom_status(self, custom_status: str) -> None: self._logger.debug(f'{self.instance_id}: Setting custom status to {custom_status}') self.__obj.set_custom_status(custom_status) + # Execution info (populated by runtime when available) + @property + def execution_info(self) -> WorkflowExecutionInfo | None: + return getattr(self, '_execution_info', None) + + def _set_execution_info(self, info: WorkflowExecutionInfo) -> None: + self._execution_info = info + def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: self._logger.debug(f'{self.instance_id}: Creating timer to fire at {fire_at} time') return self.__obj.create_timer(fire_at) @@ -67,8 +122,9 @@ def call_activity( input: TInput = None, retry_policy: Optional[RetryPolicy] = None, app_id: Optional[str] = None, + metadata: dict[str, str] | None = None, ) -> task.Task[TOutput]: - # Handle string activity names for cross-app scenarios + # Determine activity name if isinstance(activity, str): activity_name = activity if app_id is not None: @@ -77,24 +133,41 @@ def call_activity( ) else: self._logger.debug(f'{self.instance_id}: Creating activity {activity_name}') + else: + # Handle function activity objects + self._logger.debug(f'{self.instance_id}: Creating activity {activity.__name__}') + if hasattr(activity, '_dapr_alternate_name'): + activity_name = activity.__dict__['_dapr_alternate_name'] + else: + # this case should ideally never happen + activity_name = activity.__name__ - if retry_policy is None: - return self.__obj.call_activity(activity=activity_name, input=input, app_id=app_id) - return self.__obj.call_activity( - activity=activity_name, input=input, retry_policy=retry_policy.obj, app_id=app_id + # Apply outbound interceptor transformations for ALL activity calls (string and function) + transformed_input: Any = input + modified_retry_policy = retry_policy + modified_app_id = app_id + if Handlers.CALL_ACTIVITY in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CALL_ACTIVITY] + ): + result = self._outbound_handlers[Handlers.CALL_ACTIVITY]( + self, activity, input, retry_policy, app_id, 
metadata or self.get_metadata() ) + # Handle tuple return (input, retry_policy, app_id) or just input for backward compat + if isinstance(result, tuple) and len(result) == 3: + transformed_input, modified_retry_policy, modified_app_id = result + else: + transformed_input = result - # Handle function activity objects (original behavior) - self._logger.debug(f'{self.instance_id}: Creating activity {activity.__name__}') - if hasattr(activity, '_dapr_alternate_name'): - act = activity.__dict__['_dapr_alternate_name'] - else: - # this case should ideally never happen - act = activity.__name__ - if retry_policy is None: - return self.__obj.call_activity(activity=act, input=input, app_id=app_id) + # Make the actual call with potentially modified parameters + if modified_retry_policy is None: + return self.__obj.call_activity( + activity=activity_name, input=transformed_input, app_id=modified_app_id + ) return self.__obj.call_activity( - activity=act, input=input, retry_policy=retry_policy.obj, app_id=app_id + activity=activity_name, + input=transformed_input, + retry_policy=modified_retry_policy.obj, + app_id=modified_app_id, ) def call_child_workflow( @@ -105,56 +178,130 @@ def call_child_workflow( instance_id: Optional[str] = None, retry_policy: Optional[RetryPolicy] = None, app_id: Optional[str] = None, + metadata: dict[str, str] | None = None, ) -> task.Task[TOutput]: - # Handle string workflow names for cross-app scenarios - if isinstance(workflow, str): + # Determine if this is a string workflow name or function workflow + is_string_workflow = isinstance(workflow, str) + + if is_string_workflow: workflow_name = workflow self._logger.debug(f'{self.instance_id}: Creating child workflow {workflow_name}') + workflow_callable = None + else: + # Handle function workflow objects + self._logger.debug(f'{self.instance_id}: Creating child workflow {workflow.__name__}') - if retry_policy is None: - return self.__obj.call_sub_orchestrator( - workflow_name, input=input, instance_id=instance_id, app_id=app_id - ) - return self.__obj.call_sub_orchestrator( - workflow_name, - input=input, - instance_id=instance_id, - retry_policy=retry_policy.obj, - app_id=app_id, - ) + def wf(ctx: task.OrchestrationContext, inp: TInput): + dapr_wf_context = DaprWorkflowContext(ctx, self._logger.get_options()) + return workflow(dapr_wf_context, inp) - # Handle function workflow objects (original behavior) - self._logger.debug(f'{self.instance_id}: Creating child workflow {workflow.__name__}') + # copy workflow name so durabletask.worker can find the orchestrator in its registry + if hasattr(workflow, '_dapr_alternate_name'): + wf.__name__ = workflow.__dict__['_dapr_alternate_name'] + else: + # this case should ideally never happen + wf.__name__ = workflow.__name__ - def wf(ctx: task.OrchestrationContext, inp: TInput): - daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - return workflow(daprWfContext, inp) + workflow_callable = wf + workflow_name = wf.__name__ + + # Apply outbound interceptor transformations for ALL child workflow calls (string and function) + transformed_input: Any = input + modified_instance_id = instance_id + modified_retry_policy = retry_policy + modified_app_id = app_id + if Handlers.CALL_CHILD_WORKFLOW in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CALL_CHILD_WORKFLOW] + ): + result = self._outbound_handlers[Handlers.CALL_CHILD_WORKFLOW]( + self, + workflow, + input, + instance_id, + retry_policy, + app_id, + metadata or self.get_metadata(), + ) + # 
Handle tuple return (input, instance_id, retry_policy, app_id) or just input for backward compat + if isinstance(result, tuple) and len(result) == 4: + transformed_input, modified_instance_id, modified_retry_policy, modified_app_id = ( + result + ) + else: + transformed_input = result - # copy workflow name so durabletask.worker can find the orchestrator in its registry + # Make the actual call with potentially modified parameters + # For string workflows, use the workflow_name directly; for functions, use the wrapper + target = workflow_name if is_string_workflow else workflow_callable - if hasattr(workflow, '_dapr_alternate_name'): - wf.__name__ = workflow.__dict__['_dapr_alternate_name'] - else: - # this case should ideally never happen - wf.__name__ = workflow.__name__ - if retry_policy is None: + if modified_retry_policy is None: return self.__obj.call_sub_orchestrator( - wf, input=input, instance_id=instance_id, app_id=app_id + target, + input=transformed_input, + instance_id=modified_instance_id, + app_id=modified_app_id, ) return self.__obj.call_sub_orchestrator( - wf, input=input, instance_id=instance_id, retry_policy=retry_policy.obj, app_id=app_id + target, + input=transformed_input, + instance_id=modified_instance_id, + retry_policy=modified_retry_policy.obj, + app_id=modified_app_id, ) def wait_for_external_event(self, name: str) -> task.Task: self._logger.debug(f'{self.instance_id}: Waiting for external event {name}') return self.__obj.wait_for_external_event(name) - def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None: + def continue_as_new( + self, + new_input: Any, + *, + save_events: bool = False, + carryover_metadata: bool = False, + metadata: dict[str, str] | None = None, + ) -> None: + """ + Continue the workflow execution with new inputs and optional metadata or headers. + + This method allows restarting the workflow execution with new input parameters, + while optionally preserving workflow events, metadata, and/or headers. It also + integrates with workflow interceptors if configured, enabling custom modification + of inputs and associated metadata before continuation. + + Args: + new_input: Any new input to pass to the workflow upon continuation. + save_events (bool): Indicates whether to preserve the event history of the + workflow execution. Defaults to False. + carryover_metadata bool: If True, carries + over metadata from the current execution. + metadata dict[str, str] | None: If a dictionary is provided, it + will be added to the current metadata. If carryover_metadata is True, + the contents of the dictionary will be merged with the current metadata. 
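+
+        Example (illustrative):
+            # carry current metadata forward and add/override the 'tenant' key
+            ctx.continue_as_new({'step': 2}, carryover_metadata=True, metadata={'tenant': 'acme'})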
+ """ self._logger.debug(f'{self.instance_id}: Continuing as new') - self.__obj.continue_as_new(new_input, save_events=save_events) + # Allow workflow outbound interceptors (wired via runtime) to modify payload/metadata + transformed_input: Any = new_input + if Handlers.CONTINUE_AS_NEW in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CONTINUE_AS_NEW] + ): + transformed_input = self._outbound_handlers[Handlers.CONTINUE_AS_NEW]( + self, new_input, self.get_metadata() + ) + + # Merge/carry metadata if requested, unwrapping any envelope produced by interceptors + payload, base_md = unwrap_payload_with_metadata(transformed_input) + # Start with current context metadata; then layer any interceptor-provided metadata on top + current_md = self.get_metadata() or {} + effective_md = (current_md | (base_md or {})) if carryover_metadata else {} + if metadata is not None: + effective_md = effective_md | metadata + + payload = wrap_payload_with_metadata(payload, effective_md) + self.__obj.continue_as_new(payload, save_events=save_events) -def when_all(tasks: List[task.Task[T]]) -> task.WhenAllTask[T]: +def when_all(tasks: List[task.Task]) -> task.WhenAllTask: """Returns a task that completes when all of the provided tasks complete or when one of the tasks fail.""" return task.when_all(tasks) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py b/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py new file mode 100644 index 000000000..d33a02c60 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py @@ -0,0 +1,27 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +# Backward-compatible shim: import deterministic utilities from durabletask +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, + deterministic_random, + deterministic_uuid4, +) + +__all__ = [ + 'DeterministicContextMixin', + 'deterministic_random', + 'deterministic_uuid4', +] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py b/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py new file mode 100644 index 000000000..0aacd7106 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py @@ -0,0 +1,49 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +""" +Minimal, deterministic snapshots of inbound durable metadata. 
+ +Rationale +--------- + +Execution info previously mirrored many engine fields (IDs, tracing, attempts) already +available on the workflow/activity contexts. To remove redundancy and simplify usage, the +execution info types now only capture the durable ``inbound_metadata`` that was actually +propagated into this activation. Use context properties directly for engine fields. +""" + + +@dataclass +class WorkflowExecutionInfo: + """Per-activation snapshot for workflows. + + Only includes ``inbound_metadata`` that arrived with this activation. + """ + + inbound_metadata: dict[str, str] + + +@dataclass +class ActivityExecutionInfo: + """Per-activation snapshot for activities. + + Only includes ``inbound_metadata`` that arrived with this activity invocation. + """ + + inbound_metadata: dict[str, str] + activity_name: str diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py b/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py new file mode 100644 index 000000000..9823d2567 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py @@ -0,0 +1,513 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Generic, Protocol, TypeVar + +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext +from dapr.ext.workflow.workflow_context import WorkflowContext + +# Type variables for generic interceptor payload typing +TInput = TypeVar('TInput') +TWorkflowInput = TypeVar('TWorkflowInput') +TActivityInput = TypeVar('TActivityInput') + +""" +Interceptor interfaces and chain utilities for the Dapr Workflow SDK. + +Providing a single enter/exit around calls. + +IMPORTANT: Generator wrappers for async workflows +-------------------------------------------------- +When writing runtime interceptors that touch workflow execution, be careful with generator +handling. If an interceptor obtains a workflow generator from user code (e.g., an async +orchestrator adapted into a generator) it must not manually iterate it using a for-loop +and yield the produced items. Doing so breaks send()/throw() propagation back into the +inner generator, which can cause resumed results from the durable runtime to be dropped +and appear as None to awaiters. + +Best practices: +- If the interceptor participates in composition and needs to return the generator, + return it directly (do not iterate it). +- If the interceptor must wrap the generator, always use "yield from inner_gen" so that + send()/throw() are forwarded correctly. + +Context managers with async workflows +-------------------------------------- +When using context managers (like ExitStack, logging contexts, or trace contexts) in an +interceptor for async workflows, be aware that calling `next(input)` returns a generator +object immediately, NOT the final result. The generator executes later when the durable +task runtime drives it. 
+ +If you need a context manager to remain active during the workflow execution: + +**WRONG - Context exits before workflow runs:** + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + with setup_context(): + return next(input) # Returns generator, context exits immediately! + +**CORRECT - Context stays active throughout execution:** + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + def wrapper(): + with setup_context(): + gen = next(input) + result = yield from gen # Keep context alive while generator executes + return result # MUST return the result to propagate workflow output + return wrapper() + +For more complex scenarios with ExitStack or async context managers, wrap the generator +with `yield from` to ensure your context spans the entire workflow execution, including +all replay and continuation events. + +Example with ExitStack: + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + def wrapper(): + with ExitStack() as stack: + # Set up contexts (trace, logging, etc.) + stack.enter_context(trace_context(...)) + stack.enter_context(logging_context(...)) + + # Get the generator from the next interceptor/handler + gen = next(input) + + # Keep contexts alive while generator executes + result = yield from gen + return result # MUST return the result to propagate workflow output + return wrapper() + +This pattern ensures your context manager remains active during: +- Initial workflow execution +- Replays from durable state +- Continuation after awaits +- Activity calls and child workflow invocations + +CRITICAL: When using `yield from gen`, you MUST capture and return the result explicitly. +Without `return result`, the wrapper generator will return None and the workflow output +will be lost (serialized_output will be null). +""" + + +# Context metadata propagation +# ---------------------------- +# "metadata" is a durable, string-only map. It is serialized on the wire and propagates across +# boundaries (client → runtime → activity/child), surviving replays/retries. Use it when downstream +# components must observe the value. In-process ephemeral state should be handled within interceptors +# without attempting to propagate across process boundaries. + + +# ------------------------------ +# Client-side interceptor surface +# ------------------------------ + + +@dataclass +class ScheduleWorkflowRequest(Generic[TInput]): + workflow_name: str + input: TInput + instance_id: str | None + start_at: Any | None + reuse_id_policy: ( + Any | None + ) # should be used to handle the case where you want to schedule a workflow with an ID that might already exist. 
+ # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class CallChildWorkflowRequest(Generic[TInput]): + workflow_name: str + input: TInput + instance_id: str | None + retry_policy: Any | None + app_id: str | None + # Optional workflow context for outbound calls made inside workflows (not serialized and propagated across boundaries) + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class ContinueAsNewRequest(Generic[TInput]): + input: TInput + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class CallActivityRequest(Generic[TInput]): + activity_name: str + input: TInput + retry_policy: Any | None + app_id: str | None + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +class ClientInterceptor(Protocol, Generic[TInput]): + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[TInput], + next: Callable[[ScheduleWorkflowRequest[TInput]], Any], + ) -> Any: ... + + +# ------------------------------- +# Runtime-side interceptor surface +# ------------------------------- + + +@dataclass +class ExecuteWorkflowRequest(Generic[TInput]): + ctx: WorkflowContext + input: TInput + # Durable metadata (runtime chain only; not injected into user code) + metadata: dict[str, str] | None = None + + +@dataclass +class ExecuteActivityRequest(Generic[TInput]): + ctx: WorkflowActivityContext + input: TInput + # Durable metadata (runtime chain only; not injected into user code) + metadata: dict[str, str] | None = None + + +class RuntimeInterceptor(Protocol, Generic[TWorkflowInput, TActivityInput]): + def execute_workflow( + self, + input: ExecuteWorkflowRequest[TWorkflowInput], + next: Callable[[ExecuteWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: ... + + def execute_activity( + self, + input: ExecuteActivityRequest[TActivityInput], + next: Callable[[ExecuteActivityRequest[TActivityInput]], Any], + ) -> Any: ... + + +# ------------------------------ +# Convenience base classes (devex) +# ------------------------------ + + +class BaseClientInterceptor(Generic[TInput]): + """Subclass this to get method name completion and safe defaults. + + Override any of the methods to customize behavior. By default, these + methods simply call `next` unchanged. 
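Example (a minimal sketch; `TenantClientInterceptor` and the tenant value are illustrative, and how the interceptor list is registered with the workflow client is not shown in this hunk):

```python
from dapr.ext.workflow.interceptors import (
    BaseClientInterceptor,
    ScheduleWorkflowRequest,
)


class TenantClientInterceptor(BaseClientInterceptor):
    """Stamps a tenant id into the durable metadata of every scheduled workflow."""

    def schedule_new_workflow(self, input: ScheduleWorkflowRequest, next):
        md = dict(input.metadata or {})
        md.setdefault('tenant', 'acme-corp')  # illustrative value
        input.metadata = md
        return next(input)
```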
+ """ + + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[TInput], + next: Callable[[ScheduleWorkflowRequest[TInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + # No workflow-outbound methods here; use WorkflowOutboundInterceptor for those + + +class BaseRuntimeInterceptor(Generic[TWorkflowInput, TActivityInput]): + """Subclass this to get method name completion and safe defaults.""" + + def execute_workflow( + self, + input: ExecuteWorkflowRequest[TWorkflowInput], + next: Callable[[ExecuteWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + def execute_activity( + self, + input: ExecuteActivityRequest[TActivityInput], + next: Callable[[ExecuteActivityRequest[TActivityInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + +# ------------------------------ +# Helper: chain composition +# ------------------------------ + + +def compose_client_chain( + interceptors: list[ClientInterceptor], terminal: Callable[[Any], Any] +) -> Callable[[Any], Any]: + """Compose client interceptors into a single callable. + + Interceptors are applied in list order; each receives a ``next``. + The ``terminal`` callable is the final handler invoked after all interceptors; it + performs the base operation (e.g., scheduling the workflow) when the chain ends. + """ + next_fn = terminal + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: ClientInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + if isinstance(input, ScheduleWorkflowRequest): + return curr_icpt.schedule_new_workflow(input, nxt) + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn + + +# ------------------------------ +# Workflow outbound interceptor surface +# ------------------------------ + + +class WorkflowOutboundInterceptor(Protocol, Generic[TWorkflowInput, TActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[TWorkflowInput], + next: Callable[[CallChildWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: ... + + def continue_as_new( + self, + input: ContinueAsNewRequest[TWorkflowInput], + next: Callable[[ContinueAsNewRequest[TWorkflowInput]], Any], + ) -> Any: ... + + def call_activity( + self, + input: CallActivityRequest[TActivityInput], + next: Callable[[CallActivityRequest[TActivityInput]], Any], + ) -> Any: ... + + +class BaseWorkflowOutboundInterceptor(Generic[TWorkflowInput, TActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[TWorkflowInput], + next: Callable[[CallChildWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: + return next(input) + + def continue_as_new( + self, + input: ContinueAsNewRequest[TWorkflowInput], + next: Callable[[ContinueAsNewRequest[TWorkflowInput]], Any], + ) -> Any: + return next(input) + + def call_activity( + self, + input: CallActivityRequest[TActivityInput], + next: Callable[[CallActivityRequest[TActivityInput]], Any], + ) -> Any: + return next(input) + + +# ------------------------------ +# Backward-compat typing aliases +# ------------------------------ + + +def compose_workflow_outbound_chain( + interceptors: list[WorkflowOutboundInterceptor], + terminal: Callable[[Any], Any], +) -> Callable[[Any], Any]: + """Compose workflow outbound interceptors into a single callable. + + Interceptors are applied in list order; each receives a ``next``. 
+ The ``terminal`` callable is the final handler invoked after all interceptors; it + performs the base operation (e.g., preparing outbound call args) when the chain ends. + """ + next_fn = terminal + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: WorkflowOutboundInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + # Dispatch to the appropriate outbound method on the interceptor + if isinstance(input, CallActivityRequest): + return curr_icpt.call_activity(input, nxt) + if isinstance(input, CallChildWorkflowRequest): + return curr_icpt.call_child_workflow(input, nxt) + if isinstance(input, ContinueAsNewRequest): + return curr_icpt.continue_as_new(input, nxt) + # Fallback to next if input type unknown + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn + + +# ------------------------------ +# Helper: envelope for durable metadata +# ------------------------------ + +_META_KEY = '__dapr_meta__' +_META_VERSION = 1 +_PAYLOAD_KEY = '__dapr_payload__' + + +def wrap_payload_with_metadata(payload: Any, metadata: dict[str, str] | None) -> Any: + """Wrap payload in an envelope with metadata for durable persistence. + + The envelope structure allows metadata to be propagated across workflow boundaries + (client → workflow → activity → child workflow) and persisted durably alongside the + payload. This metadata survives replays, retries, and continues-as-new operations. + + Envelope structure (when metadata is present): + ```python + { + '__dapr_meta__': { + 'v': 1, # Version for future compatibility + 'metadata': { + # String key-value pairs only + 'tenant': 'acme-corp', + 'request_id': 'req-12345', + 'otel.trace_id': 'abc123...', + # ... other metadata + } + }, + '__dapr_payload__': + } + ``` + + When serialized to JSON and stored in Dapr state/history: + ```json + { + "__dapr_meta__": { + "v": 1, + "metadata": { + "tenant": "acme-corp", + "request_id": "req-12345" + } + }, + "__dapr_payload__": {"user_input": "some data"} + } + ``` + + Usage: + - **Client scheduling**: Metadata set by client interceptors is wrapped before + scheduling workflows + - **Activity calls**: Metadata is wrapped before calling activities from workflows + - **Child workflows**: Metadata is wrapped before calling child workflows + - **Continue-as-new**: Metadata can be carried over or reset + + Args: + payload: The actual data to be passed (can be any JSON-serializable type) + metadata: Optional string-only dictionary with cross-cutting concerns + (e.g., trace IDs, tenant info, request IDs) + + Returns: + If metadata is provided and non-empty, returns the envelope dict. + Otherwise returns payload unchanged (backward compatible). + + Note: + - User code never sees the envelope; it's unwrapped before execution + - Only string keys/values should be stored in metadata + - Metadata should be kept small (avoid large values, no binary data) + - Consider size limits and PII redaction policies + """ + if metadata: + return { + _META_KEY: { + 'v': _META_VERSION, + 'metadata': metadata, + }, + _PAYLOAD_KEY: payload, + } + return payload + + +def unwrap_payload_with_metadata(obj: Any) -> tuple[Any, dict[str, str] | None]: + """Extract payload and metadata from envelope if present. + + This function is called by the runtime before executing workflows/activities to + separate the user payload from the metadata. The payload is passed to user code, + while metadata is made available through the execution context. 
+ + Args: + obj: The potentially-wrapped input (may be an envelope or raw payload) + + Returns: + A tuple of (payload, metadata_dict_or_none): + - If obj is an envelope: (extracted_payload, extracted_metadata) + - If obj is not an envelope: (obj, None) + + Example: + ```python + # Envelope case + envelope = { + '__dapr_meta__': {'v': 1, 'metadata': {'tenant': 'acme'}}, + '__dapr_payload__': {'x': 1} + } + payload, metadata = unwrap_payload_with_metadata(envelope) + # payload = {'x': 1} + # metadata = {'tenant': 'acme'} + + # Non-envelope case (backward compatibility) + raw = {'x': 1} + payload, metadata = unwrap_payload_with_metadata(raw) + # payload = {'x': 1} + # metadata = None + ``` + + Note: + - Robust error handling: any exception during unwrapping treats input as raw payload + - Validates envelope structure (must have both __dapr_meta__ and __dapr_payload__) + - Returns None for metadata if envelope is malformed or metadata is not a dict + """ + try: + if isinstance(obj, dict) and _META_KEY in obj and _PAYLOAD_KEY in obj: + meta = obj.get(_META_KEY) or {} + md = meta.get('metadata') if isinstance(meta, dict) else None + return obj.get(_PAYLOAD_KEY), md if isinstance(md, dict) else None + except Exception: + # Be robust: on any error, treat as raw payload + pass + return obj, None + + +def compose_runtime_chain( + interceptors: list[RuntimeInterceptor], final_handler: Callable[[Any], Any] +): + """Compose runtime interceptors into a single callable (synchronous). + + The ``final_handler`` callable is the final handler invoked after all interceptors; it + performs the core operation (e.g., calling user workflow/activity or returning a + workflow generator) when the chain ends. + """ + next_fn = final_handler + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: RuntimeInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + if isinstance(input, ExecuteWorkflowRequest): + return curr_icpt.execute_workflow(input, nxt) + if isinstance(input, ExecuteActivityRequest): + return curr_icpt.execute_activity(input, nxt) + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py index 331ad6c2c..bdf5068da 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py @@ -17,6 +17,7 @@ from typing import Callable, TypeVar +from dapr.ext.workflow.execution_info import ActivityExecutionInfo from durabletask import task T = TypeVar('T') @@ -25,10 +26,18 @@ class WorkflowActivityContext: - """Defines properties and methods for task activity context objects.""" + """Wrapper for ``durabletask.task.ActivityContext`` with metadata helpers. + + Purpose + ------- + - Surface ``execution_info``: a per-activation snapshot that includes the + ``inbound_metadata`` actually received for this activity. + - Offer ``get_metadata()/set_metadata()`` for SDK-level durable metadata management. 
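Example (a minimal sketch; the activity name and metadata keys are illustrative):

```python
from dapr.ext.workflow import WorkflowActivityContext, WorkflowRuntime

wfr = WorkflowRuntime()

@wfr.activity(name='audited_step')
def audited_step(ctx: WorkflowActivityContext, data: dict) -> dict:
    # Durable metadata propagated from the caller (None when nothing was sent)
    md = ctx.get_metadata() or {}
    # Per-activation snapshot set by the runtime (None when unavailable)
    info = ctx.execution_info
    name = info.activity_name if info else 'audited_step'
    print(f'{name}: tenant={md.get("tenant")}')
    return data
```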
+ """ def __init__(self, ctx: task.ActivityContext): self.__obj = ctx + self._metadata: dict[str, str] | None = None @property def workflow_id(self) -> str: @@ -43,6 +52,20 @@ def task_id(self) -> int: def get_inner_context(self) -> task.ActivityContext: return self.__obj + @property + def execution_info(self) -> ActivityExecutionInfo | None: + return getattr(self, '_execution_info', None) + + def _set_execution_info(self, info: ActivityExecutionInfo) -> None: + self._execution_info = info + + # Metadata accessors (SDK-level; set by runtime inbound if available) + def set_metadata(self, metadata: dict[str, str] | None) -> None: + self._metadata = dict(metadata) if metadata else None + + def get_metadata(self) -> dict[str, str] | None: + return dict(self._metadata) if self._metadata else None + # Activities are simple functions that can be scheduled by workflows Activity = Callable[..., TOutput] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py index 8453e16ef..e8c1e6406 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,8 +15,9 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Any, Callable, Generator, Optional, TypeVar, Union +from typing import Any, Callable, Generator, TypeVar, Union +from dapr.ext.workflow.retry_policy import RetryPolicy from dapr.ext.workflow.workflow_activity_context import Activity from durabletask import task @@ -90,7 +89,7 @@ def set_custom_status(self, custom_status: str) -> None: pass @abstractmethod - def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: + def create_timer(self, fire_at: datetime | timedelta) -> task.Task: """Create a Timer Task to fire after at the specified deadline. Parameters @@ -110,8 +109,9 @@ def call_activity( self, activity: Union[Activity[TOutput], str], *, - input: Optional[TInput] = None, - app_id: Optional[str] = None, + input: TInput | None = None, + app_id: str | None = None, + retry_policy: RetryPolicy | None = None, ) -> task.Task[TOutput]: """Schedule an activity for execution. @@ -123,6 +123,7 @@ def call_activity( The JSON-serializable input (or None) to pass to the activity. app_id: str | None The AppID that will execute the activity. + retry_policy: RetryPolicy | None Returns ------- @@ -136,9 +137,10 @@ def call_child_workflow( self, orchestrator: Union[Workflow[TOutput], str], *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, - app_id: Optional[str] = None, + input: TInput | None = None, + instance_id: str | None = None, + app_id: str | None = None, + retry_policy: RetryPolicy | None = None, ) -> task.Task[TOutput]: """Schedule child-workflow function for execution. @@ -153,6 +155,9 @@ def call_child_workflow( random UUID will be used. app_id: str The AppID that will execute the workflow. + retry_policy: RetryPolicy | None + Optional retry policy for the child-workflow. When provided, failures will be retried + according to the policy. 
Returns ------- diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index 593e55c68..a2ceb1ab4 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -14,16 +14,32 @@ """ import inspect +import traceback from functools import wraps -from typing import Optional, Sequence, TypeVar, Union +from typing import Any, Awaitable, Callable, List, Optional, Sequence, TypeVar, Union import grpc -from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext, Handlers +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import ( + CallActivityRequest, + CallChildWorkflowRequest, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + WorkflowOutboundInterceptor, + compose_runtime_chain, + compose_workflow_outbound_chain, + unwrap_payload_with_metadata, + wrap_payload_with_metadata, +) from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.util import getAddress from dapr.ext.workflow.workflow_activity_context import Activity, WorkflowActivityContext from dapr.ext.workflow.workflow_context import Workflow from durabletask import task, worker +from durabletask.aio.sandbox import SandboxMode from dapr.clients import DaprInternalError from dapr.clients.http.client import DAPR_API_TOKEN_HEADER @@ -47,16 +63,19 @@ class WorkflowRuntime: def __init__( self, - host: Optional[str] = None, - port: Optional[str] = None, + host: str | None = None, + port: str | None = None, logger_options: Optional[LoggerOptions] = None, interceptors: Optional[Sequence[ClientInterceptor]] = None, maximum_concurrent_activity_work_items: Optional[int] = None, maximum_concurrent_orchestration_work_items: Optional[int] = None, maximum_thread_pool_workers: Optional[int] = None, + *, + runtime_interceptors: Optional[list[RuntimeInterceptor]] = None, + workflow_outbound_interceptors: Optional[list[WorkflowOutboundInterceptor]] = None, ): self._logger = Logger('WorkflowRuntime', logger_options) - metadata = tuple() + metadata = () if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) address = getAddress(host, port) @@ -80,16 +99,148 @@ def __init__( maximum_thread_pool_workers=maximum_thread_pool_workers, ), ) + # Interceptors + self._runtime_interceptors: List[RuntimeInterceptor] = list(runtime_interceptors or []) + self._workflow_outbound_interceptors: List[WorkflowOutboundInterceptor] = list( + workflow_outbound_interceptors or [] + ) + + # Outbound helpers apply interceptors and wrap metadata; no built-in transformations. 
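For illustration, a workflow-outbound interceptor that these helpers invoke might look like the sketch below. The class name and stamped value are hypothetical; `WorkflowRuntime` and the `workflow_outbound_interceptors` keyword argument come from this change:

```python
from dapr.ext.workflow import WorkflowRuntime
from dapr.ext.workflow.interceptors import (
    BaseWorkflowOutboundInterceptor,
    CallActivityRequest,
)


class RequestIdOutboundInterceptor(BaseWorkflowOutboundInterceptor):
    """Adds a request id to the durable metadata of every outbound activity call."""

    def call_activity(self, input: CallActivityRequest, next):
        md = dict(input.metadata or {})
        md.setdefault('request_id', 'req-12345')  # illustrative value
        input.metadata = md
        return next(input)


wfr = WorkflowRuntime(
    workflow_outbound_interceptors=[RequestIdOutboundInterceptor()],
)
```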
+ def _apply_outbound_activity( + self, + ctx: Any, + activity: Callable[..., Any] | str, + input: Any, + retry_policy: Any | None, + app_id: str | None, + metadata: dict[str, str] | None = None, + ): + # Build workflow-outbound chain to transform CallActivityRequest + name = ( + activity + if isinstance(activity, str) + else ( + activity.__dict__['_dapr_alternate_name'] + if hasattr(activity, '_dapr_alternate_name') + else activity.__name__ + ) + ) + + def terminal(term_req: CallActivityRequest) -> CallActivityRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + # Use per-context default metadata when not provided + metadata = metadata or ctx.get_metadata() + act_req = CallActivityRequest( + activity_name=name, + input=input, + retry_policy=retry_policy, + app_id=app_id, + workflow_ctx=ctx, + metadata=metadata, + ) + out = chain(act_req) + if isinstance(out, CallActivityRequest): + return ( + wrap_payload_with_metadata(out.input, out.metadata), + out.retry_policy, + out.app_id, + ) + return input, retry_policy, app_id + + def _apply_outbound_child( + self, + ctx: Any, + workflow: Callable[..., Any] | str, + input: Any, + instance_id: str | None, + retry_policy: Any | None, + app_id: str | None, + metadata: dict[str, str] | None = None, + ): + name = ( + workflow + if isinstance(workflow, str) + else ( + workflow.__dict__['_dapr_alternate_name'] + if hasattr(workflow, '_dapr_alternate_name') + else workflow.__name__ + ) + ) + + def terminal(term_req: CallChildWorkflowRequest) -> CallChildWorkflowRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + metadata = metadata or ctx.get_metadata() + child_req = CallChildWorkflowRequest( + workflow_name=name, + input=input, + instance_id=instance_id, + retry_policy=retry_policy, + app_id=app_id, + workflow_ctx=ctx, + metadata=metadata, + ) + out = chain(child_req) + if isinstance(out, CallChildWorkflowRequest): + return ( + wrap_payload_with_metadata(out.input, out.metadata), + out.instance_id, + out.retry_policy, + out.app_id, + ) + return input, instance_id, retry_policy, app_id + + def _apply_outbound_continue_as_new( + self, + ctx: Any, + new_input: Any, + metadata: dict[str, str] | None = None, + ): + # Build workflow-outbound chain to transform ContinueAsNewRequest + from dapr.ext.workflow.interceptors import ContinueAsNewRequest + + def terminal(term_req: ContinueAsNewRequest) -> ContinueAsNewRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + metadata = metadata or ctx.get_metadata() + cnr = ContinueAsNewRequest(input=new_input, workflow_ctx=ctx, metadata=metadata) + out = chain(cnr) + if isinstance(out, ContinueAsNewRequest): + return wrap_payload_with_metadata(out.input, out.metadata) + return new_input def register_workflow(self, fn: Workflow, *, name: Optional[str] = None): + # Seamlessly support async workflows using the existing API + if inspect.iscoroutinefunction(fn): + return self.register_async_workflow(fn, name=name) + self._logger.info(f"Registering workflow '{fn.__name__}' with runtime") - def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): - """Responsible to call Workflow function in orchestrationWrapper""" - daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - if inp is None: - return fn(daprWfContext) - return fn(daprWfContext, inp) + def orchestration_wrapper(ctx: 
task.OrchestrationContext, inp: Optional[TInput] = None): + """Orchestration entrypoint wrapped by runtime interceptors.""" + payload, md = unwrap_payload_with_metadata(inp) + dapr_wf_context = self._get_workflow_context(ctx, md) + + # Build interceptor chain; terminal calls the user function (generator or non-generator) + def final_handler(exec_req: ExecuteWorkflowRequest) -> Any: + try: + return ( + fn(dapr_wf_context) + if exec_req.input is None + else fn(dapr_wf_context, exec_req.input) + ) + except Exception as exc: # log and re-raise to surface failure details + self._logger.error( + f"{ctx.instance_id}: workflow '{fn.__name__}' raised {type(exc).__name__}: {exc}\n{traceback.format_exc()}" + ) + raise + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain(ExecuteWorkflowRequest(ctx=dapr_wf_context, input=payload, metadata=md)) if hasattr(fn, '_workflow_registered'): # whenever a workflow is registered, it has a _dapr_alternate_name attribute @@ -104,7 +255,7 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ self.__worker._registry.add_named_orchestrator( - fn.__dict__['_dapr_alternate_name'], orchestrationWrapper + fn.__dict__['_dapr_alternate_name'], orchestration_wrapper ) fn.__dict__['_workflow_registered'] = True @@ -114,12 +265,96 @@ def register_activity(self, fn: Activity, *, name: Optional[str] = None): """ self._logger.info(f"Registering activity '{fn.__name__}' with runtime") - def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): - """Responsible to call Activity function in activityWrapper""" - wfActivityContext = WorkflowActivityContext(ctx) - if inp is None: - return fn(wfActivityContext) - return fn(wfActivityContext, inp) + # TODO: This dual-wrapper approach could be simplified if interceptors were moved + # to durabletask-python. Another option is to wait for potential merge of durabletask-python and python-sdk + # repos, which would eliminate the need for separate wrapper detection logic here. + # This would simplify the architecture by having a single execution path handle both + # sync and async activities with interceptor support built-in at the durabletask level. 
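For reference, a runtime interceptor that this chain would invoke might look like the following sketch (the class name and logging are illustrative; `runtime_interceptors` is the keyword argument added by this change):

```python
from dapr.ext.workflow import WorkflowRuntime
from dapr.ext.workflow.interceptors import (
    BaseRuntimeInterceptor,
    ExecuteActivityRequest,
    ExecuteWorkflowRequest,
)


class LoggingRuntimeInterceptor(BaseRuntimeInterceptor):
    def execute_workflow(self, input: ExecuteWorkflowRequest, next):
        # Return whatever next() produces (a generator for orchestrators)
        # without iterating it, so send()/throw() keep working.
        return next(input)

    def execute_activity(self, input: ExecuteActivityRequest, next):
        # Async activities are awaited before this chain runs, so interceptors
        # observe the resolved result rather than a coroutine.
        print(f'activity inbound metadata: {input.metadata}')
        return next(input)


wfr = WorkflowRuntime(runtime_interceptors=[LoggingRuntimeInterceptor()])
```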
+ + if inspect.iscoroutinefunction(fn): + # Create async wrapper for async activities + async def async_activity_wrapper( + ctx: task.ActivityContext, inp: Optional[TInput] = None + ): + """Async activity entrypoint wrapped by runtime interceptors.""" + wf_activity_context = WorkflowActivityContext(ctx) + payload, md = unwrap_payload_with_metadata(inp) + # Populate inbound metadata onto activity context + wf_activity_context.set_metadata(md or {}) + + # Populate execution info + try: + # Determine activity name (registered alternate name or function __name__) + act_name = getattr(fn, '_dapr_alternate_name', fn.__name__) + ainfo = ActivityExecutionInfo(inbound_metadata=md or {}, activity_name=act_name) + wf_activity_context._set_execution_info(ainfo) + except Exception: + pass + + # Execute the async activity BEFORE the interceptor chain + # This ensures interceptors see actual results, not coroutines + try: + if payload is None: + activity_result = await fn(wf_activity_context) + else: + activity_result = await fn(wf_activity_context, payload) + except Exception as exc: + # Log details for troubleshooting (metadata, error type) + self._logger.error( + f"{ctx.orchestration_id}:{ctx.task_id} activity '{fn.__name__}' failed with {type(exc).__name__}: {exc}" + ) + self._logger.error(traceback.format_exc()) + raise + + # Now pass the result through the interceptor chain + # Interceptors can log/transform the result but not wrap the async execution + def final_handler(exec_req: ExecuteActivityRequest) -> Any: + return activity_result + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain( + ExecuteActivityRequest(ctx=wf_activity_context, input=payload, metadata=md) + ) + + wrapper = async_activity_wrapper + else: + # Create sync wrapper for sync activities + def sync_activity_wrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): + """Activity entrypoint wrapped by runtime interceptors.""" + wf_activity_context = WorkflowActivityContext(ctx) + payload, md = unwrap_payload_with_metadata(inp) + # Populate inbound metadata onto activity context + wf_activity_context.set_metadata(md or {}) + + # Populate execution info + try: + # Determine activity name (registered alternate name or function __name__) + act_name = getattr(fn, '_dapr_alternate_name', fn.__name__) + ainfo = ActivityExecutionInfo(inbound_metadata=md or {}, activity_name=act_name) + wf_activity_context._set_execution_info(ainfo) + except Exception: + pass + + def final_handler(exec_req: ExecuteActivityRequest) -> Any: + try: + # Call sync activity + if exec_req.input is None: + return fn(wf_activity_context) + return fn(wf_activity_context, exec_req.input) + except Exception as exc: + # Log details for troubleshooting (metadata, error type) + self._logger.error( + f"{ctx.orchestration_id}:{ctx.task_id} activity '{fn.__name__}' failed with {type(exc).__name__}: {exc}" + ) + self._logger.error(traceback.format_exc()) + raise + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain( + ExecuteActivityRequest(ctx=wf_activity_context, input=payload, metadata=md) + ) + + wrapper = sync_activity_wrapper if hasattr(fn, '_activity_registered'): # whenever an activity is registered, it has a _dapr_alternate_name attribute @@ -133,18 +368,30 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): else: fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ - self.__worker._registry.add_named_activity( - 
fn.__dict__['_dapr_alternate_name'], activityWrapper - ) + self.__worker._registry.add_named_activity(fn.__dict__['_dapr_alternate_name'], wrapper) fn.__dict__['_activity_registered'] = True def start(self): """Starts the listening for work items on a background thread.""" self.__worker.start() + def __enter__(self): + self.start() + return self + def shutdown(self): """Stops the listening for work items on a background thread.""" - self.__worker.stop() + try: + self._logger.info('Stopping gRPC worker...') + self.__worker.stop() + self._logger.info('Worker shutdown completed') + except Exception as exc: # pragma: no cover + # DurableTask worker may emit CANCELLED warnings during local shutdown; not fatal + self._logger.warning(f'Worker stop encountered {type(exc).__name__}: {exc}') + + def __exit__(self, exc_type, exc_val, exc_tb): + self.shutdown() + return False def workflow(self, __fn: Workflow = None, *, name: Optional[str] = None): """Decorator to register a workflow function. @@ -174,7 +421,11 @@ def add(ctx, x: int, y: int) -> int: """ def wrapper(fn: Workflow): - self.register_workflow(fn, name=name) + # Auto-detect coroutine and delegate to async registration + if inspect.iscoroutinefunction(fn): + self.register_async_workflow(fn, name=name) + else: + self.register_workflow(fn, name=name) @wraps(fn) def innerfn(): @@ -194,6 +445,121 @@ def innerfn(): return wrapper + # Async orchestrator registration (additive) + def register_async_workflow( + self, + fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]], + *, + name: Optional[str] = None, + sandbox_mode: SandboxMode = SandboxMode.BEST_EFFORT, + ) -> None: + """Register an async workflow function. + + The async workflow is wrapped by a coroutine-to-generator driver so it can be + executed by the Durable Task runtime alongside existing generator workflows. + + Args: + fn: The async workflow function, taking ``AsyncWorkflowContext`` and optional input. + name: Optional alternate name for registration. + sandbox_mode: Scoped compatibility patching mode. + """ + self._logger.info(f"Registering ASYNC workflow '{fn.__name__}' with runtime") + + if hasattr(fn, '_workflow_registered'): + alt_name = fn.__dict__['_dapr_alternate_name'] + raise ValueError(f'Workflow {fn.__name__} already registered as {alt_name}') + if hasattr(fn, '_dapr_alternate_name'): + alt_name = fn._dapr_alternate_name + if name is not None: + m = f'Workflow {fn.__name__} already has an alternate name {alt_name}' + raise ValueError(m) + else: + fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + + runner = CoroutineOrchestratorRunner(fn, sandbox_mode=sandbox_mode) + + def generator_orchestrator(ctx: task.OrchestrationContext, inp: Optional[Any] = None): + """Orchestration entrypoint wrapped by runtime interceptors.""" + payload, md = unwrap_payload_with_metadata(inp) + base_ctx = self._get_workflow_context(ctx, md) + + async_ctx = AsyncWorkflowContext(base_ctx) + + def final_handler(exec_req: ExecuteWorkflowRequest) -> Any: + # Build the generator using the (potentially shaped) input from interceptors. 
+ shaped_input = exec_req.input + return runner.to_generator(async_ctx, shaped_input) + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain(ExecuteWorkflowRequest(ctx=async_ctx, input=payload, metadata=md)) + + self.__worker._registry.add_named_orchestrator( + fn.__dict__['_dapr_alternate_name'], generator_orchestrator + ) + fn.__dict__['_workflow_registered'] = True + + def _get_workflow_context( + self, ctx: task.OrchestrationContext, metadata: dict[str, str] | None = None + ) -> DaprWorkflowContext: + """Get the workflow context and execution info for the given orchestration context and metadata. + Execution info serves as a read-only snapshot of the workflow context. + + Args: + ctx: The orchestration context. + metadata: The metadata for the workflow. + + Returns: + The workflow context. + """ + base_ctx = DaprWorkflowContext( + ctx, + self._logger.get_options(), + outbound_handlers={ + Handlers.CALL_ACTIVITY: self._apply_outbound_activity, + Handlers.CALL_CHILD_WORKFLOW: self._apply_outbound_child, + Handlers.CONTINUE_AS_NEW: self._apply_outbound_continue_as_new, + }, + ) + # Populate minimal execution info (only inbound metadata) + info = WorkflowExecutionInfo(inbound_metadata=metadata or {}) + base_ctx._set_execution_info(info) + base_ctx.set_metadata(metadata or {}) + return base_ctx + + def async_workflow( + self, + __fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]] = None, + *, + name: Optional[str] = None, + sandbox_mode: SandboxMode = SandboxMode.OFF, + ): + """Decorator to register an async workflow function. + + Usage: + @runtime.async_workflow(name="my_wf") + async def my_wf(ctx: AsyncWorkflowContext, input): + ... + """ + + def wrapper(fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]]): + self.register_async_workflow(fn, name=name, sandbox_mode=sandbox_mode) + + @wraps(fn) + def innerfn(): + return fn + + if hasattr(fn, '_dapr_alternate_name'): + innerfn.__dict__['_dapr_alternate_name'] = fn.__dict__['_dapr_alternate_name'] + else: + innerfn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + innerfn.__signature__ = inspect.signature(fn) + return innerfn + + if __fn: + return wrapper(__fn) + + return wrapper + def activity(self, __fn: Activity = None, *, name: Optional[str] = None): """Decorator to register an activity function. diff --git a/ext/dapr-ext-workflow/pytest.ini b/ext/dapr-ext-workflow/pytest.ini new file mode 100644 index 000000000..d08eef800 --- /dev/null +++ b/ext/dapr-ext-workflow/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + e2e: marks tests as end-to-end integration tests (deselect with '-m "not e2e"') + diff --git a/ext/dapr-ext-workflow/setup.cfg b/ext/dapr-ext-workflow/setup.cfg index fdf8bd4dc..870c8bebd 100644 --- a/ext/dapr-ext-workflow/setup.cfg +++ b/ext/dapr-ext-workflow/setup.cfg @@ -20,7 +20,7 @@ project_urls = Source = https://github.com/dapr/python-sdk [options] -python_requires = >=3.9 +python_requires = >=3.10 packages = find_namespace: include_package_data = True install_requires = diff --git a/ext/dapr-ext-workflow/tests/README.md b/ext/dapr-ext-workflow/tests/README.md new file mode 100644 index 000000000..6759a362d --- /dev/null +++ b/ext/dapr-ext-workflow/tests/README.md @@ -0,0 +1,94 @@ +## Workflow tests: unit, integration, and custom ports + +This directory contains unit tests (no sidecar required) and integration tests (require a running sidecar/runtime). 
+ +### Prereqs + +- Python 3.11+ (tox will create an isolated venv) +- Dapr sidecar for integration tests (HTTP and gRPC ports) +- Optional: Durable Task gRPC endpoint for DT e2e tests + +### Run all tests via tox (recommended) + +```bash +tox -e py311 +``` + +This runs: +- Core SDK tests (unittest) +- Workflow extension unit tests (pytest) +- Workflow extension integration tests (pytest) if your sidecar/runtime is reachable + +### Run only workflow unit tests + +Unit tests live at `ext/dapr-ext-workflow/tests` excluding the `integration/` subfolder. + +With tox: +```bash +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests -k "not integration" +``` + +Directly (outside tox): +```bash +pytest -q ext/dapr-ext-workflow/tests -k "not integration" +``` + +### Run workflow integration tests + +Integration tests live under `ext/dapr-ext-workflow/tests/integration/` and require a running sidecar/runtime. + +With tox: +```bash +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration +``` + +Directly (outside tox): +```bash +pytest -q ext/dapr-ext-workflow/tests/integration +``` + +If tests cannot reach your sidecar/runtime, they will skip or fail fast depending on the specific test. + +### Configure custom sidecar ports/endpoints + +The SDK reads connection settings from env vars (see `dapr.conf.global_settings`). Use these to point tests at custom ports: + +- Dapr gRPC: + - `DAPR_GRPC_ENDPOINT` (preferred): endpoint string, e.g. `dns:127.0.0.1:50051` + - or `DAPR_RUNTIME_HOST` and `DAPR_GRPC_PORT`, e.g. `DAPR_RUNTIME_HOST=127.0.0.1`, `DAPR_GRPC_PORT=50051` + +- Dapr HTTP (only for HTTP-based tests): + - `DAPR_HTTP_ENDPOINT`: e.g. `http://127.0.0.1:3600` + - or `DAPR_RUNTIME_HOST` and `DAPR_HTTP_PORT`, e.g. `DAPR_HTTP_PORT=3600` + +Examples: +```bash +# Use custom gRPC 50051 and HTTP 3600 +export DAPR_GRPC_ENDPOINT=dns:127.0.0.1:50051 +export DAPR_HTTP_ENDPOINT=http://127.0.0.1:3600 + +# Alternatively, using host/port pairs +export DAPR_RUNTIME_HOST=127.0.0.1 +export DAPR_GRPC_PORT=50051 +export DAPR_HTTP_PORT=3600 + +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration +``` + +Note: For gRPC, avoid `http://` or `https://` schemes. Use `dns:host:port` or just set host/port separately. + +### Durable Task e2e tests (optional) + +Some tests (e.g., `integration/test_async_e2e_dt.py`) talk directly to a Durable Task gRPC endpoint. They use: + +- `DURABLETASK_GRPC_ENDPOINT` (default `localhost:56178`) + +If your DT runtime listens elsewhere: +```bash +export DURABLETASK_GRPC_ENDPOINT=127.0.0.1:56179 +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py +``` + + + + diff --git a/ext/dapr-ext-workflow/tests/_fakes.py b/ext/dapr-ext-workflow/tests/_fakes.py new file mode 100644 index 000000000..603c89173 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/_fakes.py @@ -0,0 +1,68 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + + +class FakeOrchestrationContext: + def __init__( + self, + *, + instance_id: str = 'wf-1', + current_utc_datetime: datetime | None = None, + is_replaying: bool = False, + workflow_name: str = 'wf', + trace_parent: str | None = None, + trace_state: str | None = None, + orchestration_span_id: str | None = None, + workflow_attempt: int | None = None, + ) -> None: + self.instance_id = instance_id + self.current_utc_datetime = ( + current_utc_datetime if current_utc_datetime else datetime(2025, 1, 1) + ) + self.is_replaying = is_replaying + self.workflow_name = workflow_name + self.trace_parent = trace_parent + self.trace_state = trace_state + self.orchestration_span_id = orchestration_span_id + self.workflow_attempt = workflow_attempt + + +class FakeActivityContext: + def __init__( + self, + *, + orchestration_id: str = 'wf-1', + task_id: int = 1, + attempt: int | None = None, + trace_parent: str | None = None, + trace_state: str | None = None, + workflow_span_id: str | None = None, + ) -> None: + self.orchestration_id = orchestration_id + self.task_id = task_id + self.trace_parent = trace_parent + self.trace_state = trace_state + self.workflow_span_id = workflow_span_id + + +def make_orch_ctx(**overrides: Any) -> FakeOrchestrationContext: + return FakeOrchestrationContext(**overrides) + + +def make_act_ctx(**overrides: Any) -> FakeActivityContext: + return FakeActivityContext(**overrides) diff --git a/ext/dapr-ext-workflow/tests/conftest.py b/ext/dapr-ext-workflow/tests/conftest.py new file mode 100644 index 000000000..f20a225e7 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/conftest.py @@ -0,0 +1,74 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Ensure tests prefer the local python-sdk repository over any installed site-packages +# This helps when running pytest directly (outside tox/CI), so changes in the repo are exercised. +from __future__ import annotations # noqa: I001 + +import sys +from pathlib import Path +import importlib +import pytest + + +def pytest_configure(config): # noqa: D401 (pytest hook) + """Pytest configuration hook that prepends the repo root to sys.path. + + This ensures `import dapr` resolves to the local source tree when running tests directly. + Under tox/CI (editable installs), this is a no-op but still safe. 
+ """ + try: + # ext/dapr-ext-workflow/tests/conftest.py -> repo root is 3 parents up + repo_root = Path(__file__).resolve().parents[3] + except Exception: + return + + repo_str = str(repo_root) + if repo_str not in sys.path: + sys.path.insert(0, repo_str) + + # Best-effort diagnostic: show where dapr was imported from + try: + dapr_mod = importlib.import_module('dapr') + dapr_path = Path(getattr(dapr_mod, '__file__', '')).resolve() + where = 'site-packages' if 'site-packages' in str(dapr_path) else 'local-repo' + print(f'[dapr-ext-workflow/tests] dapr resolved from {where}: {dapr_path}', file=sys.stderr) + except Exception: + # If dapr isn't importable yet, that's fine; tests importing it later will use modified sys.path + pass + + +@pytest.fixture(autouse=True) +def cleanup_workflow_registrations(request): + """Clean up workflow/activity registration markers after each test. + + This prevents test interference when the same function objects are reused across tests. + The workflow runtime marks functions with _dapr_alternate_name and _activity_registered + attributes, which can cause 'already registered' errors in subsequent tests. + """ + yield # Run the test + + # After test completes, clean up functions defined in the test module + test_module = sys.modules.get(request.module.__name__) + if test_module: + for name in dir(test_module): + obj = getattr(test_module, name, None) + if callable(obj) and hasattr(obj, '__dict__'): + try: + # Only clean up if __dict__ is writable (not mappingproxy) + if isinstance(obj.__dict__, dict): + obj.__dict__.pop('_dapr_alternate_name', None) + obj.__dict__.pop('_activity_registered', None) + except (AttributeError, TypeError): + # Skip objects with read-only __dict__ + pass diff --git a/ext/dapr-ext-workflow/tests/integration/__init__.py b/ext/dapr-ext-workflow/tests/integration/__init__.py new file mode 100644 index 000000000..f4bf93490 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +"""Integration tests for dapr-ext-workflow.""" diff --git a/ext/dapr-ext-workflow/tests/integration/dapr_test_utils.py b/ext/dapr-ext-workflow/tests/integration/dapr_test_utils.py new file mode 100644 index 000000000..1fadc5614 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/dapr_test_utils.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- + +""" +Utility functions for Dapr integration/e2e tests. + +Provides helpers for starting Dapr sidecars and managing test infrastructure during integratino (e2e) tests +""" + +from __future__ import annotations + +import os +import shutil +import socket +import subprocess +import tempfile +import time + +import pytest + + +def is_dapr_cli_available() -> bool: + """Check if the Dapr CLI is installed and available.""" + return shutil.which('dapr') is not None + + +# Skip decorator for tests that require Dapr CLI +skip_if_no_dapr = pytest.mark.skipif( + not is_dapr_cli_available(), + reason='Dapr CLI is not installed. 
Install from https://docs.dapr.io/getting-started/install-dapr-cli/', +) + + +def is_runtime_available(host: str, port: int) -> bool: + """Check if a Dapr runtime is available at the given host:port.""" + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + result = sock.connect_ex((host, port)) + sock.close() + return result == 0 + except Exception: + return False + + +def start_dapr_sidecar( + app_id: str, grpc_port: int, http_port: int, keep_alive_seconds: int = 600 +) -> subprocess.Popen: + """Start a Dapr sidecar using dapr CLI. + + Args: + app_id: Application ID for the Dapr sidecar + grpc_port: gRPC port for Dapr API + http_port: HTTP port for Dapr API + + Returns: + Process handle for the Dapr sidecar + + Raises: + RuntimeError: If the sidecar fails to start + """ + # Create temporary components directory with state store + components_dir = tempfile.mkdtemp(prefix=f'dapr-components-{app_id}-') + statestore_path = os.path.join(components_dir, 'statestore.yaml') + + statestore_yaml = """apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: statestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" + - name: actorStateStore + value: "true" +""" + + with open(statestore_path, 'w') as f: + f.write(statestore_yaml) + + cmd = [ + 'dapr', + 'run', + '--app-id', + app_id, + '--dapr-grpc-port', + str(grpc_port), + '--dapr-http-port', + str(http_port), + '--resources-path', + components_dir, + '--log-level', + 'info', + '--', + 'sleep', + str(keep_alive_seconds), + ] + + print( + f'[Setup] Starting dapr for {app_id} on grpc={grpc_port}, http={http_port} with components in {components_dir}', + flush=True, + ) + + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + + # Wait for sidecar to start + for i in range(30): # 30 second timeout + poll_result = proc.poll() + if poll_result is not None: + raise RuntimeError(f'dapr run for {app_id} exited with code {poll_result}') + if is_runtime_available('127.0.0.1', grpc_port): + print(f'[Setup] dapr for {app_id} is ready!', flush=True) + # Store components directory for cleanup + proc.components_dir = components_dir + return proc + if i % 5 == 0: + print(f'[Setup] Waiting for dapr {app_id}... ({i}/30)', flush=True) + time.sleep(1) + + # If we get here, it failed + proc.kill() + raise RuntimeError(f'Failed to start dapr for {app_id} - timeout after 30s') + + +def stop_dapr_sidecar(proc: subprocess.Popen, app_id: str): + """Stop a dapr sidecar process. + + Args: + proc: Process handle for the Dapr sidecar + app_id: Application ID (for logging) + """ + print(f'[Cleanup] Stopping dapr for {app_id}', flush=True) + try: + proc.terminate() + proc.wait(timeout=5) + print(f'[Cleanup] dapr for {app_id} stopped', flush=True) + except subprocess.TimeoutExpired: + print(f'[Cleanup] Force killing dapr for {app_id}', flush=True) + proc.kill() + proc.wait() + + # Clean up temporary components directory if it exists + if hasattr(proc, 'components_dir'): + try: + shutil.rmtree(proc.components_dir) + except Exception: + pass + + +def dapr_sidecar_fixture( + app_id: str, grpc_port: int, http_port: int, keep_alive_seconds: int = 600 +): + """Create a pytest fixture for a Dapr sidecar. + + This is a fixture factory that can be used to create reusable sidecar fixtures. 
+ + Example: + @pytest.fixture(scope='module') + def my_dapr_sidecar(): + return dapr_sidecar_fixture('my-app', 50001, 3001) + + Args: + app_id: Application ID for the Dapr sidecar + grpc_port: gRPC port for Dapr API + http_port: HTTP port for Dapr API + + Yields: + Process handle for the Dapr sidecar + """ + proc = None + try: + print( + f'[Setup] Starting dapr for {app_id} on grpc={grpc_port}, http={http_port}', flush=True + ) + proc = start_dapr_sidecar(app_id, grpc_port, http_port, keep_alive_seconds) + yield proc + except Exception as e: + pytest.skip(f'Could not start dapr for {app_id}: {e}') + finally: + if proc: + stop_dapr_sidecar(proc, app_id) diff --git a/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py b/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py new file mode 100644 index 000000000..fe38b8ce8 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- + +""" +Async e2e tests using durabletask worker/client directly. + +These validate basic orchestration behavior against a running sidecar +to isolate environment issues from WorkflowRuntime wiring. +""" + +from __future__ import annotations + +import time + +import pytest +from durabletask.aio import AsyncWorkflowContext +from durabletask.client import TaskHubGrpcClient +from durabletask.worker import TaskHubGrpcWorker + +from .dapr_test_utils import dapr_sidecar_fixture, skip_if_no_dapr + +pytestmark = [pytest.mark.e2e, skip_if_no_dapr] + +# Dapr configuration for e2e tests +DAPR_GRPC_PORT = 50012 +DAPR_HTTP_PORT = 3502 + + +@pytest.fixture(scope='module') +def dapr_sidecar(): + """Start Dapr sidecar for all e2e tests in this module.""" + yield from dapr_sidecar_fixture('test-e2e-dt', DAPR_GRPC_PORT, DAPR_HTTP_PORT) + + +def get_workger_client_worker() -> tuple[TaskHubGrpcWorker, TaskHubGrpcClient]: + return TaskHubGrpcWorker(host_address=f'localhost:{DAPR_GRPC_PORT}'), TaskHubGrpcClient( + host_address=f'localhost:{DAPR_GRPC_PORT}' + ) + + +def test_dt_simple_activity_e2e(dapr_sidecar): + worker, client = get_workger_client_worker() + + def act(ctx, x: int) -> int: + return x * 3 + + worker.add_activity(act) + + @worker.add_async_orchestrator + async def orch(ctx: AsyncWorkflowContext, x: int) -> int: + return await ctx.call_activity(act, input=x) + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-act-{int(time.time() * 1000)}' + client.schedule_new_orchestration(orch, input=5, instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + # Output is JSON serialized scalar + assert st.serialized_output.strip() in ('15', '"15"') + finally: + try: + worker.stop() + except Exception: + pass + + +def test_dt_timer_e2e(dapr_sidecar): + worker, client = get_workger_client_worker() + + @worker.add_async_orchestrator + async def orch(ctx: AsyncWorkflowContext, delay: float) -> dict: + start = ctx.now() + await ctx.create_timer(delay) + end = ctx.now() + return {'start': start.isoformat(), 'end': end.isoformat(), 'delay': delay} + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-timer-{int(time.time() * 1000)}' + delay = 1.0 + client.schedule_new_orchestration(orch, input=delay, instance_id=iid) 
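+        # The orchestrator records ctx.now() before and after the durable timer, so reaching
+        # completion below implies the timer actually fired and the orchestration resumed.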
+ st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + finally: + try: + worker.stop() + except Exception: + pass + + +def test_dt_sub_orchestrator_e2e(dapr_sidecar): + worker, client = get_workger_client_worker() + + def act(ctx, s: str) -> str: + return f'A:{s}' + + worker.add_activity(act) + + async def child(ctx: AsyncWorkflowContext, s: str) -> str: + print('[E2E DEBUG] child start', s) + try: + res = await ctx.call_activity(act, input=s) + print('[E2E DEBUG] child done', res) + return res + except Exception as exc: # pragma: no cover - troubleshooting aid + import traceback as _tb + + print('[E2E DEBUG] child exception:', type(exc).__name__, str(exc)) + print(_tb.format_exc()) + raise + + # Explicit registration to avoid decorator replacing symbol with a string in newer versions + worker.add_async_orchestrator(child) + + async def parent(ctx: AsyncWorkflowContext, s: str) -> str: + print('[E2E DEBUG] parent start', s) + try: + c = await ctx.call_sub_orchestrator(child, input=s) + out = f'P:{c}' + print('[E2E DEBUG] parent done', out) + return out + except Exception as exc: # pragma: no cover - troubleshooting aid + import traceback as _tb + + print('[E2E DEBUG] parent exception:', type(exc).__name__, str(exc)) + print(_tb.format_exc()) + raise + + worker.add_async_orchestrator(parent) + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-sub-{int(time.time() * 1000)}' + print('[E2E DEBUG] scheduling instance', iid) + client.schedule_new_orchestration(parent, input='x', instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + if st.runtime_status.name != 'COMPLETED': + # Print orchestration state details to aid debugging + print('[E2E DEBUG] orchestration FAILED; details:') + to_json = getattr(st, 'to_json', None) + if callable(to_json): + try: + print(to_json()) + except Exception: + pass + print('status=', getattr(st, 'runtime_status', None)) + print('output=', getattr(st, 'serialized_output', None)) + print('failure=', getattr(st, 'failure_details', None)) + assert st.runtime_status.name == 'COMPLETED' + finally: + try: + worker.stop() + except Exception: + pass + + +def test_dt_async_activity_e2e(dapr_sidecar): + """Test async activities with actual async I/O operations.""" + worker, client = get_workger_client_worker() + + # Define an async activity that performs async work + async def async_io_activity(ctx, x: int) -> dict: + """Async activity that simulates I/O-bound work.""" + import asyncio + + # Simulate async I/O (e.g., network request, database query) + await asyncio.sleep(0.01) + result = x * 5 + await asyncio.sleep(0.01) + return {'input': x, 'output': result, 'async': True} + + worker.add_activity(async_io_activity) + + @worker.add_async_orchestrator + async def orch(ctx: AsyncWorkflowContext, x: int) -> dict: + # Call async activity + result = await ctx.call_activity(async_io_activity, input=x) + return result + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-async-act-{int(time.time() * 1000)}' + client.schedule_new_orchestration(orch, input=7, instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + assert st.runtime_status.name 
== 'COMPLETED' + + # Verify the output contains expected values + import json + + output = json.loads(st.serialized_output) + assert output['input'] == 7 + assert output['output'] == 35 + assert output['async'] is True + finally: + try: + worker.stop() + except Exception: + pass diff --git a/ext/dapr-ext-workflow/tests/integration/test_cross_app_interceptors_e2e.py b/ext/dapr-ext-workflow/tests/integration/test_cross_app_interceptors_e2e.py new file mode 100644 index 000000000..40fc07fd7 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/test_cross_app_interceptors_e2e.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- + +""" +E2E tests for cross-app workflow interceptors. + +These tests require multiple Dapr sidecars to be running. +Run with: pytest -m e2e tests/integration/test_cross_app_interceptors_e2e.py +""" + +from __future__ import annotations + +import multiprocessing +import time +from dataclasses import replace +from datetime import timedelta + +import pytest +from dapr.ext.workflow import ( + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + DaprWorkflowClient, + RetryPolicy, + WorkflowRuntime, +) + +from .dapr_test_utils import dapr_sidecar_fixture, skip_if_no_dapr + +pytestmark = [pytest.mark.e2e, skip_if_no_dapr] + +# Configuration for test apps +app1 = { + 'id': 'cross-app-test-app1', + 'grpc_port': 50101, + 'http_port': 3101, +} +app2 = { + 'id': 'cross-app-test-app2', + 'grpc_port': 50102, + 'http_port': 3102, +} + + +@pytest.fixture(scope='module') +def dapr_app1_sidecar(): + """Start dapr sidecar for app1.""" + yield from dapr_sidecar_fixture(app1['id'], app1['grpc_port'], app1['http_port']) + + +@pytest.fixture(scope='module') +def dapr_app2_sidecar(): + """Start dapr sidecar for app2.""" + yield from dapr_sidecar_fixture(app2['id'], app2['grpc_port'], app2['http_port']) + + +def _run_app2_worker(retry_count: multiprocessing.Value): + """Run app2 workflow worker (activity provider).""" + print('[App2 Worker] Starting...', flush=True) + runtime = WorkflowRuntime(host='127.0.0.1', port=str(app2['grpc_port'])) + + @runtime.activity(name='remote_activity') + def remote_activity(ctx, input_data): + retry_count.value += 1 # only one process can access this value so no need to lock + print( + f'[App2 Worker] remote_activity called (attempt {retry_count.value}) with: {input_data}', + flush=True, + ) + + # Fail first 2 attempts to test retry policy + if retry_count.value <= 2: + print(f'[App2 Worker] Simulating failure on attempt {retry_count.value}', flush=True) + raise Exception(f'Simulated failure on attempt {retry_count.value}') + + print(f'[App2 Worker] Success on attempt {retry_count.value}', flush=True) + return f'remote-result-{input_data}' + + print('[App2 Worker] Starting runtime...', flush=True) + runtime.start() + print('[App2 Worker] Runtime started, waiting for work...', flush=True) + + # Keep running until parent terminates us + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + pass + finally: + print(f'[App2 Worker] Final retry count: {retry_count.value}', flush=True) + print('[App2 Worker] Shutting down...', flush=True) + runtime.shutdown() + + +def test_cross_app_interceptor_modifies_retry_policy(dapr_app1_sidecar, dapr_app2_sidecar): + """Test that interceptors can modify retry_policy for cross-app activity calls. 
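+
+    The outbound interceptor injects a RetryPolicy when the call site provides none, and the
+    remote activity fails its first two attempts, so a completed run also proves the injected
+    policy was enforced across the app boundary.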
+ + This test requires two Dapr sidecars: + - app1: Runs workflow with interceptor that sets retry policy + - app2: Provides the remote activity + """ + print('\n[Test] Starting app2 worker process...', flush=True) + retry_count = multiprocessing.Manager().Value('i', 0) # type: ignore + # Start app2 worker in background process + app2_process = multiprocessing.Process(target=_run_app2_worker, args=(retry_count,)) + app2_process.start() + + try: + # Give app2 worker time to start and register activities + print('[Test] Waiting for app2 worker to register activities...', flush=True) + time.sleep(10) + + # Track interceptor calls + interceptor_calls = [] + + class TestRetryInterceptor(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): + print( + f'[Test] Interceptor called: {request.activity_name}, app_id={request.app_id}', + flush=True, + ) + # Record the call + interceptor_calls.append( + { + 'activity_name': request.activity_name, + 'app_id': request.app_id, + 'had_retry_policy': request.retry_policy is not None, + } + ) + + # Add retry policy if none exists + retry_policy = request.retry_policy + if retry_policy is None: + retry_policy = RetryPolicy( + max_number_of_attempts=3, + first_retry_interval=timedelta(milliseconds=100), + max_retry_interval=timedelta(seconds=2), + ) + + return next(replace(request, retry_policy=retry_policy)) + + print('[Test] Creating app1 runtime with interceptor...', flush=True) + runtime = WorkflowRuntime( + host='127.0.0.1', + port=str(app1['grpc_port']), + workflow_outbound_interceptors=[TestRetryInterceptor()], + ) + + @runtime.workflow(name='cross_app_workflow') + def cross_app_workflow(ctx, input_data): + print(f'[Test] Workflow executing with input: {input_data}', flush=True) + # Call cross-app activity - should go through interceptor + result = yield ctx.call_activity('remote_activity', input=input_data, app_id=app2['id']) + print(f'[Test] Workflow got result: {result}', flush=True) + return result + + print('[Test] Starting app1 runtime...', flush=True) + runtime.start() + time.sleep(5) # Give runtime time to start + + try: + print('[Test] Creating workflow client...', flush=True) + client = DaprWorkflowClient(host='127.0.0.1', port=str(app1['grpc_port'])) + instance_id = f'test-cross-app-{int(time.time())}' + + print(f'[Test] Scheduling workflow with instance_id: {instance_id}', flush=True) + # Schedule and run workflow + client.schedule_new_workflow( + workflow=cross_app_workflow, instance_id=instance_id, input='test-data' + ) + + print('[Test] Waiting for workflow completion...', flush=True) + # Wait for completion - should succeed after retries + state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + + print( + f'[Test] Workflow completed with status: {state.runtime_status.name if state else "None"}', + flush=True, + ) + + # Verify workflow completed successfully (after retries) + assert state is not None, 'Workflow state should not be None' + print(f'[Test] Workflow status: {state.runtime_status.name}', flush=True) + if state.runtime_status.name != 'COMPLETED': + print(f'[Test] Workflow failed: {state.serialized_output}', flush=True) + assert state.runtime_status.name == 'COMPLETED', ( + f'Expected COMPLETED but got {state.runtime_status.name}' + ) + + # Verify the workflow result is correct (proves retry succeeded) + import json + + result = json.loads(state.serialized_output) if state.serialized_output else None + print(f'[Test] Workflow result: {result}', flush=True) + 
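+            # remote_activity fails its first two attempts, so a matching result here proves
+            # the retry policy injected by the outbound interceptor was actually applied to
+            # the cross-app call.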
assert result == 'remote-result-test-data', ( + f'Expected "remote-result-test-data" but got {result}' + ) + + # Verify interceptor was called + print(f'[Test] Interceptor calls: {interceptor_calls}', flush=True) + assert len(interceptor_calls) >= 1, ( + f'Expected at least 1 interceptor call, got {len(interceptor_calls)}' + ) + assert interceptor_calls[0]['activity_name'] == 'remote_activity' + assert interceptor_calls[0]['app_id'] == app2['id'] + assert interceptor_calls[0]['had_retry_policy'] is False + assert retry_count.value == 3, f'Expected retry count to be 3, got {retry_count.value}' + + print('[Test] All assertions passed! Activity succeeded after retries.', flush=True) + + finally: + print('[Test] Shutting down app1 runtime...', flush=True) + runtime.shutdown() + + finally: + # Clean up app2 worker + print('[Test] Terminating app2 worker...', flush=True) + app2_process.terminate() + app2_process.join(timeout=5) + if app2_process.is_alive(): + print('[Test] Force killing app2 worker...', flush=True) + app2_process.kill() + app2_process.join() + + +if __name__ == '__main__': + # For manual testing + pytest.main([__file__, '-v', '-m', 'e2e']) diff --git a/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py b/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py new file mode 100644 index 000000000..b841883bf --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py @@ -0,0 +1,1055 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +import dataclasses +import time +from datetime import timedelta +from typing import Any + +import pytest +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + DaprWorkflowContext, + RetryPolicy, + WorkflowRuntime, +) +from dapr.ext.workflow.interceptors import ( + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ClientInterceptor, + ExecuteActivityRequest, + ExecuteWorkflowRequest, +) + +from .dapr_test_utils import dapr_sidecar_fixture, skip_if_no_dapr + +pytestmark = [pytest.mark.e2e, skip_if_no_dapr] + +# Dapr configuration for integration tests - use function scope with unique ports +_port_counter = 0 + +# Whether to purge workflows after tests +purge = False + + +@pytest.fixture(scope='function') +def dapr_config(): + """Allocate unique Dapr configuration for each test. + + We had problems with module-wise tests hanging when running multiple tests had run. 
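+    Allocating a fresh app id and unique ports for every test keeps each sidecar isolated.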
+ """ + global _port_counter + _port_counter += 1 + return { + 'app_id': f'test-integration-{_port_counter}', + 'grpc_port': 50011 + _port_counter, + 'http_port': 3501 + _port_counter, + } + + +@pytest.fixture(scope='function') +def dapr_sidecar(dapr_config): + """Start fresh Dapr sidecar for each test to avoid degradation.""" + yield from dapr_sidecar_fixture( + dapr_config['app_id'], + dapr_config['grpc_port'], + dapr_config['http_port'], + keep_alive_seconds=120, + ) + + +def wf_client(dapr_config, interceptors: list[ClientInterceptor] = None): + return DaprWorkflowClient( + host='127.0.0.1', port=str(dapr_config['grpc_port']), interceptors=interceptors + ) + + +def wfr( + dapr_config, + runtime_interceptors: list[BaseRuntimeInterceptor] = None, + workflow_outbound_interceptors: list[BaseWorkflowOutboundInterceptor] = None, +): + return WorkflowRuntime( + host='127.0.0.1', + port=str(dapr_config['grpc_port']), + runtime_interceptors=runtime_interceptors, + workflow_outbound_interceptors=workflow_outbound_interceptors, + ) + + +def test_integration_suspension_and_buffering(dapr_sidecar, dapr_config): + runtime = wfr(dapr_config) + + @runtime.async_workflow(name='suspend_orchestrator_async') + async def suspend_orchestrator(ctx: AsyncWorkflowContext): + # Expose suspension state via custom status + ctx.set_custom_status({'is_suspended': getattr(ctx, 'is_suspended', False)}) + # Wait for 'resume_event' and then complete + data = await ctx.wait_for_external_event('resume_event') + return {'resumed_with': data} + + runtime.start() + try: + # Allow connection to stabilize before scheduling + time.sleep(3) + + client = wf_client(dapr_config) + instance_id = f'suspend-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=suspend_orchestrator, instance_id=instance_id) + + # Wait until started + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + + # Pause and verify state becomes SUSPENDED and custom status updates on next activation + client.pause_workflow(instance_id) + # Give the worker time to process suspension + time.sleep(1) + state = client.get_workflow_state(instance_id) + assert state is not None + assert state.runtime_status.name in ( + 'SUSPENDED', + 'RUNNING', + ) # some hubs report SUSPENDED explicitly + + # While suspended, raise the event; it should buffer + client.raise_workflow_event(instance_id, 'resume_event', data={'ok': True}) + + # Resume and expect completion + client.resume_workflow(instance_id) + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + assert final is not None + assert final.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +def test_integration_generator_metadata_propagation(dapr_sidecar, dapr_config): + runtime = wfr(dapr_config) + + @runtime.activity(name='recv_md_gen') + def recv_md_gen(ctx, _=None): + return ctx.get_metadata() or {} + + @runtime.workflow(name='gen_parent_sets_md') + def parent_gen(ctx: DaprWorkflowContext): + ctx.set_metadata({'tenant': 'acme', 'tier': 'gold'}) + md = yield ctx.call_activity(recv_md_gen, input=None) + return md + + runtime.start() + try: + time.sleep(3) + client = wf_client(dapr_config) + iid = f'gen-md-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent_gen, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + import json as _json + + 
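+        # parent_gen sets {'tenant': 'acme', 'tier': 'gold'} and returns what recv_md_gen
+        # observed, so the serialized output should round-trip those exact keys.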
out = _json.loads(state.to_json().get('serialized_output') or '{}') + assert out.get('tenant') == 'acme' + assert out.get('tier') == 'gold' + finally: + runtime.shutdown() + + +def test_integration_trace_context_child_workflow(dapr_sidecar, dapr_config): + runtime = wfr(dapr_config) + + @runtime.activity(name='trace_probe') + def trace_probe(ctx, _=None): + return { + 'tp': getattr(ctx, 'trace_parent', None), + 'ts': getattr(ctx, 'trace_state', None), + 'wf_span': getattr(ctx, 'workflow_span_id', None), + } + + @runtime.async_workflow(name='child_trace') + async def child(ctx: AsyncWorkflowContext, _=None): + return { + 'wf_tp': getattr(ctx, 'trace_parent', None), + 'wf_ts': getattr(ctx, 'trace_state', None), + 'wf_span': getattr(ctx, 'workflow_span_id', None), + 'act': await ctx.call_activity(trace_probe, input=None), + } + + @runtime.async_workflow(name='parent_trace') + async def parent(ctx: AsyncWorkflowContext): + child_out = await ctx.call_child_workflow(child, input=None) + return { + 'parent_tp': getattr(ctx, 'trace_parent', None), + 'parent_span': getattr(ctx, 'workflow_span_id', None), + 'child': child_out, + } + + runtime.start() + try: + time.sleep(3) + client = wf_client(dapr_config) + iid = f'trace-child-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + data = _json.loads(st.to_json().get('serialized_output') or '{}') + + # TODO: assert more specifically when we have trace context information + + # Parent (engine-provided fields may be absent depending on runtime build/config) + assert isinstance(data.get('parent_tp'), (str, type(None))) + assert isinstance(data.get('parent_span'), (str, type(None))) + # Child orchestrator fields + _child = data.get('child') or {} + assert isinstance(_child.get('wf_tp'), (str, type(None))) + assert isinstance(_child.get('wf_span'), (str, type(None))) + # Activity fields under child + act = _child.get('act') or {} + assert isinstance(act.get('tp'), (str, type(None))) + assert isinstance(act.get('wf_span'), (str, type(None))) + + finally: + runtime.shutdown() + + +def test_integration_trace_context_child_workflow_injected_metadata(dapr_sidecar, dapr_config): + # Deterministic trace propagation using interceptors via durable metadata + from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ScheduleWorkflowRequest, + ) + + TRACE_KEY = 'otel.trace_id' + + class InjectTraceClient(BaseClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next( + ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + ) + + class InjectTraceOutbound(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next(dataclasses.replace(request, metadata=md)) + + def call_child_workflow(self, request: CallChildWorkflowRequest, next): + md = dict(request.metadata or {}) + 
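+            # setdefault keeps any trace id already present on the request (for example one
+            # injected by an earlier interceptor) instead of overwriting it.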
md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next(dataclasses.replace(request, metadata=md)) + + class RestoreTraceRuntime(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Ensure metadata arrives + assert isinstance((request.metadata or {}).get(TRACE_KEY), str) + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): + assert isinstance((request.metadata or {}).get(TRACE_KEY), str) + return next(request) + + runtime = wfr( + dapr_config, + runtime_interceptors=[RestoreTraceRuntime()], + workflow_outbound_interceptors=[InjectTraceOutbound()], + ) + + @runtime.activity(name='trace_probe2') + def trace_probe2(ctx, _=None): + return getattr(ctx, 'get_metadata', lambda: {})().get(TRACE_KEY) + + @runtime.async_workflow(name='child_trace2') + async def child2(ctx: AsyncWorkflowContext, _=None): + return { + 'wf_md': (ctx.get_metadata() or {}).get(TRACE_KEY), + 'act_md': await ctx.call_activity(trace_probe2, input=None), + } + + @runtime.async_workflow(name='parent_trace2') + async def parent2(ctx: AsyncWorkflowContext): + out = await ctx.call_child_workflow(child2, input=None) + return { + 'parent_md': (ctx.get_metadata() or {}).get(TRACE_KEY), + 'child': out, + } + + runtime.start() + try: + time.sleep(3) + client = wf_client(dapr_config, interceptors=[InjectTraceClient()]) + iid = f'trace-child-md-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent2, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + data = _json.loads(st.to_json().get('serialized_output') or '{}') + assert data.get('parent_md') == 'sdk-trace-123' + child = data.get('child') or {} + assert child.get('wf_md') == 'sdk-trace-123' + assert child.get('act_md') == 'sdk-trace-123' + finally: + runtime.shutdown() + + +def test_integration_termination_semantics(dapr_sidecar, dapr_config): + runtime = wfr(dapr_config) + + @runtime.async_workflow(name='termination_orchestrator_async') + async def termination_orchestrator(ctx: AsyncWorkflowContext): + # Long timer; test will terminate before it fires + await ctx.create_timer(300.0) + return 'not-reached' + + print(list(runtime._WorkflowRuntime__worker._registry.orchestrators.keys())) + + runtime.start() + try: + time.sleep(3) + + client = wf_client(dapr_config) + instance_id = f'term-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=termination_orchestrator, instance_id=instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + + # Terminate and assert TERMINATED state, not raising inside orchestrator + client.terminate_workflow(instance_id, output='terminated') + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + assert final is not None + assert final.runtime_status.name == 'TERMINATED' + finally: + runtime.shutdown() + + +def test_integration_when_any_first_wins(dapr_sidecar, dapr_config): + runtime = wfr(dapr_config) + + @runtime.async_workflow(name='when_any_async') + async def when_any_orchestrator(ctx: AsyncWorkflowContext): + first = await ctx.when_any( + [ + ctx.wait_for_external_event('go'), + ctx.create_timer(300.0), + ] + ) + # Return a simple, serializable value (winner's result) to avoid output serialization issues + try: + result = first.get_result() + except Exception: + result = None + return 
{'winner_result': result} + + runtime.start() + try: + time.sleep(2) + + client = wf_client(dapr_config) + instance_id = f'whenany-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=when_any_orchestrator, instance_id=instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + # Confirm RUNNING state before raising event (mitigates race conditions) + try: + st = client.get_workflow_state(instance_id, fetch_payloads=False) + if ( + st is None + or getattr(st, 'runtime_status', None) is None + or st.runtime_status.name != 'RUNNING' + ): + end = time.time() + 10 + while time.time() < end: + st = client.get_workflow_state(instance_id, fetch_payloads=False) + if ( + st is not None + and getattr(st, 'runtime_status', None) is not None + and st.runtime_status.name == 'RUNNING' + ): + break + time.sleep(0.2) + except Exception: + pass + + # Raise event immediately to win the when_any + client.raise_workflow_event(instance_id, 'go', data={'ok': True}) + + # Brief delay to allow event processing, then strictly use DaprWorkflowClient + time.sleep(1.0) + final = None + try: + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + except TimeoutError: + final = None + if final is None: + deadline = time.time() + 30 + while time.time() < deadline: + s = client.get_workflow_state(instance_id, fetch_payloads=False) + if s is not None and getattr(s, 'runtime_status', None) is not None: + if s.runtime_status.name in ('COMPLETED', 'FAILED', 'TERMINATED'): + final = s + break + time.sleep(0.5) + assert final is not None + assert final.runtime_status.name == 'COMPLETED' + # TODO: when sidecar exposes command diagnostics, assert only one command set was emitted + finally: + runtime.shutdown() + + +def test_integration_async_activity_completes(dapr_sidecar, dapr_config): + runtime = wfr(dapr_config) + + @runtime.activity(name='echo_int') + def echo_act(ctx, x: int) -> int: + return x + + @runtime.async_workflow(name='async_activity_once') + async def wf(ctx: AsyncWorkflowContext): + out = await ctx.call_activity(echo_act, input=7) + return out + + runtime.start() + iid = None + try: + time.sleep(3) + client = wf_client(dapr_config) + iid = f'act-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + if state.runtime_status.name != 'COMPLETED': + fd = getattr(state, 'failure_details', None) + msg = getattr(fd, 'message', None) if fd else None + et = getattr(fd, 'error_type', None) if fd else None + print(f'[INTEGRATION DEBUG] Failure details: {et} {msg}') + assert state.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + if iid: + try: + client.purge_workflow(iid, recursive=True) + except Exception: + pass + + +def test_integration_metadata_outbound_to_activity(dapr_sidecar, dapr_config): + runtime = wfr(dapr_config) + + @runtime.activity(name='recv_md') + def recv_md(ctx, _=None): + md = ctx.get_metadata() if hasattr(ctx, 'get_metadata') else {} + return md + + @runtime.async_workflow(name='wf_with_md') + async def wf(ctx: AsyncWorkflowContext): + ctx.set_metadata({'tenant': 'acme'}) + md = await ctx.call_activity(recv_md, input=None) + return md + + runtime.start() + try: + time.sleep(3) + client = wf_client(dapr_config) + iid = f'md-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + 
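+        # Scheduling does not wait for execution; wait_for_workflow_start below confirms the
+        # worker has picked up the instance before we block on completion.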
client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +def test_integration_metadata_outbound_to_child_workflow(dapr_sidecar, dapr_config): + runtime = wfr(dapr_config) + + @runtime.async_workflow(name='child_recv_md') + async def child(ctx: AsyncWorkflowContext, _=None): + # Echo inbound metadata + return ctx.get_metadata() or {} + + @runtime.async_workflow(name='parent_sets_md') + async def parent(ctx: AsyncWorkflowContext): + ctx.set_metadata({'tenant': 'acme', 'role': 'user'}) + out = await ctx.call_child_workflow(child, input=None) + return out + + runtime.start() + try: + time.sleep(3) + client = wf_client(dapr_config) + iid = f'md-child-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + # Validate output has metadata keys + data = state.to_json() + import json as _json + + out = _json.loads(data.get('serialized_output') or '{}') + assert out.get('tenant') == 'acme' + assert out.get('role') == 'user' + finally: + runtime.shutdown() + + +def test_integration_trace_context_with_runtime_interceptors(dapr_sidecar, dapr_config): + """E2E: Verify trace_parent and orchestration_span_id via runtime interceptors.""" + records = { # captured by interceptor + 'wf_tp': None, + 'wf_span': None, + 'act_tp': None, + 'act_span': None, + } + + class TraceInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + ctx = request.ctx + try: + records['wf_tp'] = getattr(ctx, 'trace_parent', None) + records['wf_span'] = getattr(ctx, 'workflow_span_id', None) + except Exception: + pass + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + ctx = request.ctx + try: + records['act_tp'] = getattr(ctx, 'trace_parent', None) + # Activity contexts don't have orchestration_span_id; capture task span if present + records['act_span'] = getattr(ctx, 'activity_span_id', None) + except Exception: + pass + return next(request) + + runtime = wfr(dapr_config, runtime_interceptors=[TraceInterceptor()]) + + @runtime.activity(name='trace_probe') + def trace_probe(ctx, _=None): + # Return trace context seen inside activity + return { + 'trace_parent': getattr(ctx, 'trace_parent', None), + 'trace_state': getattr(ctx, 'trace_state', None), + } + + @runtime.async_workflow(name='trace_parent_wf') + async def wf(ctx: AsyncWorkflowContext): + # Access orchestration span id and trace parent from workflow context + _ = getattr(ctx, 'workflow_span_id', None) + _ = getattr(ctx, 'trace_parent', None) + return await ctx.call_activity(trace_probe, input=None) + + runtime.start() + try: + time.sleep(3) + client = wf_client(dapr_config) + iid = f'trace-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(state.to_json().get('serialized_output') or '{}') + # Activity 
returned strings (may be empty); assert types + assert isinstance(out.get('trace_parent'), (str, type(None))) + assert isinstance(out.get('trace_state'), (str, type(None))) + # Interceptor captured workflow and activity contexts + wf_tp = records['wf_tp'] + wf_span = records['wf_span'] + act_tp = records['act_tp'] + # TODO: assert more specifically when we have trace context information + assert isinstance(wf_tp, (str, type(None))) + assert isinstance(wf_span, (str, type(None))) + assert isinstance(act_tp, (str, type(None))) + # If we have a workflow span id, it should appear as parent-id inside activity traceparent + if isinstance(wf_span, str) and wf_span and isinstance(act_tp, str) and act_tp: + assert wf_span.lower() in act_tp.lower() + finally: + runtime.shutdown() + + +def test_integration_runtime_shutdown_is_clean(dapr_sidecar, dapr_config): + runtime = wfr(dapr_config) + + @runtime.async_workflow(name='noop') + async def noop(ctx: AsyncWorkflowContext): + return 'ok' + + runtime.start() + try: + time.sleep(2) + client = wf_client(dapr_config) + iid = f'shutdown-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=noop, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=30) + assert st is not None and st.runtime_status.name == 'COMPLETED' + finally: + # Call shutdown multiple times to ensure idempotent and clean behavior + for _ in range(3): + try: + runtime.shutdown() + except Exception: + # Test should not raise even if worker logs cancellation warnings + assert False, 'runtime.shutdown() raised unexpectedly' + # Recreate and shutdown again to ensure no lingering background threads break next startup + rt2 = wfr(dapr_config) + rt2.start() + try: + time.sleep(1) + finally: + try: + rt2.shutdown() + except Exception: + assert False, 'second runtime.shutdown() raised unexpectedly' + + +def test_integration_continue_as_new_outbound_interceptor_metadata(dapr_sidecar, dapr_config): + # Verify continue_as_new outbound interceptor can inject metadata carried to the new run + from dapr.ext.workflow import BaseWorkflowOutboundInterceptor + + INJECT_KEY = 'injected' + + class InjectOnContinueAsNew(BaseWorkflowOutboundInterceptor): + def continue_as_new(self, request, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(INJECT_KEY, 'yes') + request.metadata = md + return next(request) + + runtime = wfr(dapr_config, workflow_outbound_interceptors=[InjectOnContinueAsNew()]) + + @runtime.workflow(name='continue_as_new_probe') + def wf(ctx, arg: dict | None = None): + if not arg or arg.get('phase') != 'second': + ctx.set_metadata({'tenant': 'acme'}) + # carry over existing metadata; interceptor will also inject + ctx.continue_as_new({'phase': 'second'}, carryover_metadata=True) + return # Must not yield after continue_as_new + # Second run: return inbound metadata observed + return ctx.get_metadata() or {} + + runtime.start() + try: + time.sleep(2) + client = wf_client(dapr_config) + iid = f'can-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + # Confirm both carried and injected metadata are present + assert out.get('tenant') 
== 'acme' + assert out.get(INJECT_KEY) == 'yes' + finally: + runtime.shutdown() + + +def test_integration_child_workflow_attempt_exposed(dapr_sidecar, dapr_config): + # Verify that child workflow ctx exposes workflow_attempt + runtime = wfr(dapr_config) + + @runtime.async_workflow(name='child_probe_attempt') + async def child_probe_attempt(ctx: AsyncWorkflowContext, _=None): + att = getattr(ctx, 'workflow_attempt', None) + return {'wf_attempt': att} + + @runtime.async_workflow(name='parent_calls_child_for_attempt') + async def parent_calls_child_for_attempt(ctx: AsyncWorkflowContext): + return await ctx.call_child_workflow(child_probe_attempt, input=None) + + runtime.start() + try: + time.sleep(2) + client = wf_client(dapr_config) + iid = f'child-attempt-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent_calls_child_for_attempt, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + val = out.get('wf_attempt', None) + assert (val is None) or isinstance(val, int) + finally: + runtime.shutdown() + + +def test_integration_async_contextvars_trace_propagation(dapr_sidecar, dapr_config): + # Demonstrates contextvars-based trace propagation via interceptors in async workflows + import contextvars + import json as _json + + from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ScheduleWorkflowRequest, + ) + + TRACE_KEY = 'otel.trace_ctx' + current_trace: contextvars.ContextVar[str | None] = contextvars.ContextVar( + 'trace', default=None + ) + + class CVClient(BaseClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'wf-parent') + return next( + ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + ) + + class CVOutbound(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, current_trace.get()) + return next(dataclasses.replace(request, metadata=md)) + + def call_child_workflow(self, request: CallChildWorkflowRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, current_trace.get()) + return next(dataclasses.replace(request, metadata=md)) + + class CVRuntime(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + prev = current_trace.set((request.metadata or {}).get(TRACE_KEY)) + try: + return next(request) + finally: + current_trace.reset(prev) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + prev = current_trace.set((request.metadata or {}).get(TRACE_KEY)) + try: + return next(request) + finally: + current_trace.reset(prev) + + runtime = wfr( + dapr_config, + runtime_interceptors=[CVRuntime()], + workflow_outbound_interceptors=[CVOutbound()], + ) + + @runtime.activity(name='cv_probe') + def cv_probe(_ctx, _=None): + before = current_trace.get() + tok = 
current_trace.set(f'{before}/act') if before else None + try: + inner = current_trace.get() + finally: + if tok is not None: + current_trace.reset(tok) + after = current_trace.get() + return {'before': before, 'inner': inner, 'after': after} + + flaky_call_count = [0] + + @runtime.activity(name='cv_flaky_probe') + def cv_flaky_probe(ctx, _=None): + before = current_trace.get() + flaky_call_count[0] += 1 + print(f'----------> flaky_call_count: {flaky_call_count[0]}') + if flaky_call_count[0] == 1: + # Fail first attempt to trigger retry + raise Exception('fail-once') + tok = current_trace.set(f'{before}/act-retry') if before else None + try: + inner = current_trace.get() + finally: + if tok is not None: + current_trace.reset(tok) + after = current_trace.get() + return {'before': before, 'inner': inner, 'after': after} + + @runtime.async_workflow(name='cv_child') + async def cv_child(ctx: AsyncWorkflowContext, _=None): + before = current_trace.get() + tok = current_trace.set(f'{before}/child') if before else None + try: + act = await ctx.call_activity(cv_probe, input=None) + finally: + if tok is not None: + current_trace.reset(tok) + restored = current_trace.get() + return {'before': before, 'restored': restored, 'act': act} + + @runtime.async_workflow(name='cv_parent') + async def cv_parent(ctx: AsyncWorkflowContext, _=None): + from datetime import timedelta + + from dapr.ext.workflow import RetryPolicy + + top_before = current_trace.get() + child = await ctx.call_child_workflow(cv_child, input=None) + after_child = current_trace.get() + act = await ctx.call_activity(cv_probe, input=None) + after_act = current_trace.get() + act_retry = await ctx.call_activity( + cv_flaky_probe, + input=None, + retry_policy=RetryPolicy( + first_retry_interval=timedelta(seconds=0), max_number_of_attempts=3 + ), + ) + return { + 'before': top_before, + 'child': child, + 'act': act, + 'act_retry': act_retry, + 'after_child': after_child, + 'after_act': after_act, + } + + runtime.start() + client = wf_client(dapr_config, interceptors=[CVClient()]) + try: + time.sleep(2) + iid = f'cv-ctx-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=cv_parent, instance_id=iid) + purge = True + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + out = _json.loads(st.to_json().get('serialized_output') or '{}') + # Top-level activity sees parent trace context during execution + act = out.get('act') or {} + assert act.get('before') == 'wf-parent' + assert act.get('inner') == 'wf-parent/act' + assert act.get('after') == 'wf-parent' + # Child workflow's activity at least inherits parent context + child = out.get('child') or {} + child_act = child.get('act') or {} + assert child_act.get('before') == 'wf-parent' + assert child_act.get('inner') == 'wf-parent/act' + assert child_act.get('after') == 'wf-parent' + # Flaky activity retried: second attempt succeeds and returns with parent context + act_retry = out.get('act_retry') or {} + assert act_retry.get('before') == 'wf-parent' + assert act_retry.get('inner') == 'wf-parent/act-retry' + assert act_retry.get('after') == 'wf-parent' + finally: + if purge: + client.purge_workflow(iid, recursive=True) + runtime.shutdown() + + +def test_runtime_interceptor_shapes_async_input(dapr_sidecar, dapr_config): + class ShapeInput(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: 
ignore[override] + data = request.input + # Mutate input passed to workflow + if isinstance(data, dict): + shaped = {**data, 'shaped': True} + else: + shaped = {'value': data, 'shaped': True} + request.input = shaped + return next(request) + + # Recreate runtime with interceptor wired in + runtime = wfr(dapr_config, runtime_interceptors=[ShapeInput()]) + + @runtime.async_workflow(name='wf_shape_input') + async def wf_shape_input(ctx: AsyncWorkflowContext, arg: dict | None = None): + # Verify shaped input is observed by the workflow + return arg + + runtime.start() + try: + client = wf_client(dapr_config) + iid = f'shape-{id(runtime)}' + client.schedule_new_workflow(workflow=wf_shape_input, instance_id=iid, input={'x': 1}) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + assert out.get('x') == 1 + assert out.get('shaped') is True + finally: + runtime.shutdown() + + +def test_runtime_interceptor_context_manager_with_async_workflow(dapr_sidecar, dapr_config): + """Test that context managers stay active during async workflow execution.""" + runtime = wfr(dapr_config) + + # Track when context enters and exits + context_state = {'entered': False, 'exited': False, 'workflow_ran': False} + + class ContextInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + # Wrapper generator to keep context manager alive + def wrapper(): + from contextlib import ExitStack + + with ExitStack(): + # Mark context as entered + context_state['entered'] = True + + # Get the workflow generator + gen = next(request) + + # Use yield from to keep context alive during execution + yield from gen + + # Context will exit after generator completes + context_state['exited'] = True + + return wrapper() + + runtime = wfr(dapr_config, runtime_interceptors=[ContextInterceptor()]) + + @runtime.async_workflow(name='wf_context_test') + async def wf_context_test(ctx: AsyncWorkflowContext, arg: dict | None = None): + context_state['workflow_ran'] = True + return {'result': 'ok'} + + runtime.start() + try: + client = wf_client(dapr_config) + iid = f'ctx-test-{id(runtime)}' + client.schedule_new_workflow(workflow=wf_context_test, instance_id=iid, input={}) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + + # Verify context manager was active during workflow execution + assert context_state['entered'], 'Context should have been entered' + assert context_state['workflow_ran'], 'Workflow should have executed' + assert context_state['exited'], 'Context should have exited after completion' + finally: + runtime.shutdown() + + +def test_outbound_interceptor_can_modify_retry_policy(dapr_sidecar, dapr_config): + """Test that outbound interceptors can inspect and modify retry_policy and app_id fields.""" + captured_requests: dict[str, Any] = {} + + class CaptureAndModifyInterceptor(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): + # Capture the request details + captured_requests['activity_name'] = request.activity_name + captured_requests['activity_retry_policy'] = request.retry_policy + + # Modify retry_policy and app_id if 
not set + retry_policy = request.retry_policy + if retry_policy is None: + retry_policy = RetryPolicy( + max_number_of_attempts=2, + first_retry_interval=timedelta(milliseconds=200), + max_retry_interval=timedelta(seconds=3), + ) + captured_requests['activity_retry_modified'] = True + + return next(dataclasses.replace(request, retry_policy=retry_policy)) + + def call_child_workflow(self, request: CallChildWorkflowRequest, next): + # Capture the request details + captured_requests['child_workflow_name'] = request.workflow_name + captured_requests['child_retry_policy'] = request.retry_policy + + # Modify retry_policy if not set + retry_policy = request.retry_policy + if retry_policy is None: + retry_policy = RetryPolicy( + max_number_of_attempts=2, + first_retry_interval=timedelta(milliseconds=200), + max_retry_interval=timedelta(seconds=3), + ) + captured_requests['child_retry_modified'] = True + + return next(dataclasses.replace(request, retry_policy=retry_policy)) + + runtime = wfr(dapr_config, workflow_outbound_interceptors=[CaptureAndModifyInterceptor()]) + + @runtime.activity(name='test_activity') + def test_activity(ctx, arg): + return 'activity-result' + + @runtime.workflow(name='child_wf') + def child_wf(ctx, arg): + return 'child-result' + + @runtime.workflow(name='parent_wf') + def parent_wf(ctx, arg): + # Call activity without retry_policy or app_id + result1 = yield ctx.call_activity('test_activity', input={'x': 1}) + # Call child workflow without retry_policy or app_id + result2 = yield ctx.call_child_workflow('child_wf', input={'y': 2}) + return {'activity': result1, 'child': result2} + + runtime.start() + client = wf_client(dapr_config) + purge = False + try: + iid = f'retry-appid-test-{id(runtime)}' + client.schedule_new_workflow(workflow=parent_wf, instance_id=iid, input={}) + purge = True + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + + # Verify interceptor captured and modified the requests + assert captured_requests['activity_name'] == 'test_activity' + assert captured_requests['activity_retry_policy'] is None + assert captured_requests.get('activity_retry_modified') is True + + assert captured_requests['child_workflow_name'] == 'child_wf' + assert captured_requests['child_retry_policy'] is None + assert captured_requests.get('child_retry_modified') is True + finally: + if purge: + client.purge_workflow(iid, recursive=True) + runtime.shutdown() diff --git a/ext/dapr-ext-workflow/tests/test_async_activity_registration.py b/ext/dapr-ext-workflow/tests/test_async_activity_registration.py new file mode 100644 index 000000000..b64502845 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_activity_registration.py @@ -0,0 +1,136 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import asyncio +import inspect + +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class _FakeRegistry: + def __init__(self): + self.activities = {} + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeActivityContext: + def __init__(self): + self.orchestration_id = 'test-orch-id' + self.task_id = 1 + + +def test_activity_decorator_supports_async(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.activity(name='async_act') + async def async_act(ctx, x: int) -> int: + await asyncio.sleep(0) # Simulate async work + return x + 2 + + # Ensure registered + reg = rt._WorkflowRuntime__worker._registry + assert 'async_act' in reg.activities + + # Verify the wrapper is async + wrapper = reg.activities['async_act'] + assert inspect.iscoroutinefunction(wrapper), 'Async activity wrapper should be a coroutine' + + # Call the wrapper and ensure it returns a coroutine that can be awaited + ctx = _FakeActivityContext() + coro = wrapper(ctx, 5) + assert inspect.iscoroutine(coro), 'Async wrapper should return a coroutine' + + # Run the coroutine and verify result + out = asyncio.run(coro) + assert out == 7 + + +def test_activity_decorator_supports_sync(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.activity(name='sync_act') + def sync_act(ctx, x: int) -> int: + return x * 3 + + # Ensure registered + reg = rt._WorkflowRuntime__worker._registry + assert 'sync_act' in reg.activities + + # Verify the wrapper is sync + wrapper = reg.activities['sync_act'] + assert not inspect.iscoroutinefunction(wrapper), ( + 'Sync activity wrapper should not be a coroutine' + ) + + # Call the wrapper directly (no await needed) + ctx = _FakeActivityContext() + out = wrapper(ctx, 4) + assert out == 12 + + +def test_async_and_sync_activities_coexist(monkeypatch): + """Test that both async and sync activities can be registered in the same runtime.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.activity(name='sync_act') + def sync_act(ctx, x: int) -> int: + return x * 2 + + @rt.activity(name='async_act') + async def async_act(ctx, x: int) -> int: + await asyncio.sleep(0) + return x + 10 + + # Ensure both registered + reg = rt._WorkflowRuntime__worker._registry + assert 'sync_act' in reg.activities + assert 'async_act' in reg.activities + + # Verify correct wrapper types + sync_wrapper = reg.activities['sync_act'] + async_wrapper = reg.activities['async_act'] + assert not inspect.iscoroutinefunction(sync_wrapper) + assert inspect.iscoroutinefunction(async_wrapper) + + # Verify both work correctly + ctx = _FakeActivityContext() + sync_result = sync_wrapper(ctx, 5) + assert sync_result == 10 + + async_result = asyncio.run(async_wrapper(ctx, 5)) + assert async_result == 15 diff --git a/ext/dapr-ext-workflow/tests/test_async_api_coverage.py b/ext/dapr-ext-workflow/tests/test_async_api_coverage.py new file mode 100644 index 000000000..c4828e231 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_api_coverage.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from datetime import datetime + +from dapr.ext.workflow.aio import AsyncWorkflowContext + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1) + self.instance_id = 'iid-cov' + self._status = None + + def set_custom_status(self, status): + self._status = status + + def continue_as_new(self, new_input, *, save_events=False): + self._continued = (new_input, save_events) + + # methods used by awaitables + def call_activity(self, activity, *, input=None, retry_policy=None, app_id=None): + class _T: + pass + + return _T() + + def call_child_workflow( + self, workflow, *, input=None, instance_id=None, retry_policy=None, app_id=None + ): + class _T: + pass + + return _T() + + def create_timer(self, fire_at): + class _T: + pass + + return _T() + + def wait_for_external_event(self, name: str): + class _T: + pass + + return _T() + + +def test_async_context_exposes_required_methods(): + base = FakeCtx() + ctx = AsyncWorkflowContext(base) + + # basic deterministic utils existence + assert isinstance(ctx.now(), datetime) + _ = ctx.random() + _ = ctx.uuid4() + + # pass-throughs + ctx.set_custom_status('ok') + assert base._status == 'ok' + ctx.continue_as_new({'foo': 1}, save_events=True) + assert getattr(base, '_continued', None) == ({'foo': 1}, True) + + # awaitable constructors do not raise + ctx.call_activity(lambda: None, input={'x': 1}) + ctx.call_child_workflow(lambda: None) + ctx.create_timer(1.0) + ctx.wait_for_external_event('go') + ctx.when_all([]) + ctx.when_any([]) diff --git a/ext/dapr-ext-workflow/tests/test_async_context.py b/ext/dapr-ext-workflow/tests/test_async_context.py new file mode 100644 index 000000000..1bc1bb1c4 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_context.py @@ -0,0 +1,196 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from datetime import datetime, timedelta, timezone + +from dapr.ext.workflow import AsyncWorkflowContext +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext +from dapr.ext.workflow.workflow_context import WorkflowContext + + +class DummyBaseCtx: + def __init__(self): + self.instance_id = 'abc-123' + # freeze a deterministic timestamp + self.current_utc_datetime = datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + self.is_replaying = False + self._custom_status = None + self._continued = None + self._metadata = None + + # Minimal orchestration API used by async awaitables + def create_timer(self, fire_at): + # In unit tests we don't run the task; just return a simple sentinel object + return object() + + def set_custom_status(self, s: str): + self._custom_status = s + + def continue_as_new(self, new_input, *, save_events: bool = False): + self._continued = (new_input, save_events) + + # Metadata parity + def set_metadata(self, md): + self._metadata = md + + def get_metadata(self): + return self._metadata + + @property + def execution_info(self): + return self._ei + + +def test_parity_properties_and_now(): + ctx = AsyncWorkflowContext(DummyBaseCtx()) + assert ctx.instance_id == 'abc-123' + assert ctx.current_utc_datetime == datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + # now() should mirror current_utc_datetime + assert ctx.now() == ctx.current_utc_datetime + + +def test_timer_accepts_float_and_timedelta(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + + # Float should be interpreted as seconds and produce a SleepAwaitable + aw1 = ctx.create_timer(1.5) + # Timedelta should pass through + aw2 = ctx.create_timer(timedelta(seconds=2)) + + # We only assert types by duck-typing public attribute presence to avoid + # importing internal classes in tests + assert hasattr(aw1, '_ctx') and hasattr(aw1, '__await__') + assert hasattr(aw2, '_ctx') and hasattr(aw2, '__await__') + + +def test_wait_for_external_event_and_concurrency_factories(): + ctx = AsyncWorkflowContext(DummyBaseCtx()) + + evt = ctx.wait_for_external_event('go') + assert hasattr(evt, '__await__') + + # when_all/when_any return awaitables + a = ctx.create_timer(0.1) + b = ctx.create_timer(0.2) + + all_aw = ctx.when_all([a, b]) + any_aw = ctx.when_any([a, b]) + + for x in (all_aw, any_aw): + assert hasattr(x, '__await__') + + +def test_deterministic_utils_and_passthroughs(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + + rnd = ctx.random() + # should behave like a random.Random-like object; test a stable first value + val = rnd.random() + # Just assert it is within (0,1) and stable across two calls to the seeded RNG instance + assert 0.0 < val < 1.0 + assert rnd.random() != val # next value changes + + uid = ctx.uuid4() + # Should be a UUID-like string representation + assert isinstance(str(uid), str) and len(str(uid)) >= 32 + + # passthroughs + ctx.set_custom_status('hello') + assert base._custom_status == 'hello' + + ctx.continue_as_new({'x': 1}, save_events=True) + assert base._continued == ({'x': 1}, True) + + +def test_async_metadata_api_and_execution_info(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + ctx.set_metadata({'k': 'v'}) + assert ctx._metadata == {'k': 'v'} + assert ctx.get_metadata() == {'k': 'v'} + # Note: execution_info is no longer directly accessible on AsyncWorkflowContext + # (it was removed from durabletask). Use DaprWorkflowContext for execution_info. 
+ + +def test_async_outbound_metadata_plumbed_into_awaitables(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + a = ctx.call_activity(lambda: None, input=1, metadata={'m': 'n'}) + c = ctx.call_child_workflow(lambda c, x: None, input=2, metadata={'x': 'y'}) + # Introspect for test (internal attribute) + assert getattr(a, '_metadata', None) == {'m': 'n'} + assert getattr(c, '_metadata', None) == {'x': 'y'} + + +def test_async_parity_surface_exists(): + # Guard: ensure essential parity members exist + ctx = AsyncWorkflowContext(DummyBaseCtx()) + for name in ( + 'set_metadata', + 'get_metadata', + 'execution_info', + 'call_activity', + 'call_child_workflow', + 'continue_as_new', + ): + assert hasattr(ctx, name) + + +def test_public_api_parity_against_workflowcontext_abc(): + # Derive the required sync API surface from the ABC plus metadata/execution_info + required = { + name + for name, attr in WorkflowContext.__dict__.items() + if getattr(attr, '__isabstractmethod__', False) + } + required.update({'set_metadata', 'get_metadata', 'execution_info'}) + + # Async context must expose the same names + async_ctx = AsyncWorkflowContext(DummyBaseCtx()) + missing_in_async = [name for name in required if not hasattr(async_ctx, name)] + assert not missing_in_async, f'AsyncWorkflowContext missing: {missing_in_async}' + + # Sync context should also expose these names + class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'abc-123' + self.current_utc_datetime = datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + self.is_replaying = False + + def set_custom_status(self, s: str): + pass + + def create_timer(self, fire_at): + return object() + + def wait_for_external_event(self, name: str): + return object() + + def continue_as_new(self, new_input, *, save_events: bool = False): + pass + + def call_activity( + self, *, activity, input=None, retry_policy=None, app_id: str | None = None + ): + return object() + + def call_sub_orchestrator( + self, fn, *, input=None, instance_id=None, retry_policy=None, app_id: str | None = None + ): + return object() + + sync_ctx = DaprWorkflowContext(_FakeOrchCtx()) + missing_in_sync = [name for name in required if not hasattr(sync_ctx, name)] + assert not missing_in_sync, f'DaprWorkflowContext missing: {missing_in_sync}' diff --git a/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py b/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py new file mode 100644 index 000000000..f1df67202 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +def test_workflow_decorator_detects_async_and_registers(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='async_wf') + async def async_wf(ctx: AsyncWorkflowContext, x: int) -> int: + # no awaits to keep simple + return x + 1 + + # ensure it was placed into registry + reg = rt._WorkflowRuntime__worker._registry + assert 'async_wf' in reg.orchestrators diff --git a/ext/dapr-ext-workflow/tests/test_deterministic.py b/ext/dapr-ext-workflow/tests/test_deterministic.py new file mode 100644 index 000000000..fa76f22fc --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_deterministic.py @@ -0,0 +1,74 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import datetime as _dt + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext + +""" +Tests for deterministic helpers shared across workflow contexts. 
+""" + + +class _FakeBaseCtx: + def __init__(self, instance_id: str, dt: _dt.datetime): + self.instance_id = instance_id + self.current_utc_datetime = dt + + +def _fixed_dt(): + return _dt.datetime(2024, 1, 1) + + +def test_random_string_deterministic_across_instances_async(): + base = _FakeBaseCtx('iid-1', _fixed_dt()) + a_ctx = AsyncWorkflowContext(base) + b_ctx = AsyncWorkflowContext(base) + a = a_ctx.random_string(16) + b = b_ctx.random_string(16) + assert a == b + + +def test_random_string_deterministic_across_context_types(): + base = _FakeBaseCtx('iid-2', _fixed_dt()) + a_ctx = AsyncWorkflowContext(base) + s1 = a_ctx.random_string(12) + + # Minimal fake orchestration context for DaprWorkflowContext + d_ctx = DaprWorkflowContext(base) + s2 = d_ctx.random_string(12) + assert s1 == s2 + + +def test_random_string_respects_alphabet(): + base = _FakeBaseCtx('iid-3', _fixed_dt()) + ctx = AsyncWorkflowContext(base) + s = ctx.random_string(20, alphabet='abc') + assert set(s).issubset(set('abc')) + + +def test_random_string_length_and_edge_cases(): + base = _FakeBaseCtx('iid-4', _fixed_dt()) + ctx = AsyncWorkflowContext(base) + + assert ctx.random_string(0) == '' + + with pytest.raises(ValueError): + ctx.random_string(-1) + + with pytest.raises(ValueError): + ctx.random_string(5, alphabet='') diff --git a/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py b/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py new file mode 100644 index 000000000..7154e5b2b --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py @@ -0,0 +1,560 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest +from dapr.ext.workflow import ( + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + WorkflowRuntime, +) + +from ._fakes import make_act_ctx as _make_act_ctx +from ._fakes import make_orch_ctx as _make_orch_ctx + +""" +Comprehensive inbound interceptor tests for Dapr WorkflowRuntime. + +Tests the current interceptor system for runtime-side workflow and activity execution. 
+""" + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _TracingInterceptor(RuntimeInterceptor): + """Interceptor that injects and restores trace context.""" + + def __init__(self, events: list[str]): + self.events = events + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Extract tracing from input + tracing_data = None + if isinstance(request.input, dict) and 'tracing' in request.input: + tracing_data = request.input['tracing'] + self.events.append(f'wf_trace_restored:{tracing_data}') + + # Call next in chain + result = next(request) + + if tracing_data: + self.events.append(f'wf_trace_cleanup:{tracing_data}') + + return result + + def execute_activity(self, request: ExecuteActivityRequest, next): + # Extract tracing from input + tracing_data = None + if isinstance(request.input, dict) and 'tracing' in request.input: + tracing_data = request.input['tracing'] + self.events.append(f'act_trace_restored:{tracing_data}') + + # Call next in chain + result = next(request) + + if tracing_data: + self.events.append(f'act_trace_cleanup:{tracing_data}') + + return result + + +class _LoggingInterceptor(RuntimeInterceptor): + """Interceptor that logs workflow and activity execution.""" + + def __init__(self, events: list[str], label: str): + self.events = events + self.label = label + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + self.events.append(f'{self.label}:wf_start:{request.input!r}') + try: + result = next(request) + self.events.append(f'{self.label}:wf_complete:{result!r}') + return result + except Exception as e: + self.events.append(f'{self.label}:wf_error:{type(e).__name__}') + raise + + def execute_activity(self, request: ExecuteActivityRequest, next): + self.events.append(f'{self.label}:act_start:{request.input!r}') + try: + result = next(request) + self.events.append(f'{self.label}:act_complete:{result!r}') + return result + except Exception as e: + self.events.append(f'{self.label}:act_error:{type(e).__name__}') + raise + + +class _ValidationInterceptor(RuntimeInterceptor): + """Interceptor that validates inputs and outputs.""" + + def __init__(self, events: list[str]): + self.events = events + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Validate input + if isinstance(request.input, dict) and request.input.get('invalid'): + self.events.append('wf_validation_failed') + raise ValueError('Invalid workflow input') + + self.events.append('wf_validation_passed') + result = next(request) + + # Validate output + if isinstance(result, dict) and result.get('invalid_output'): + self.events.append('wf_output_validation_failed') + raise ValueError('Invalid workflow output') + + self.events.append('wf_output_validation_passed') + return result + + def execute_activity(self, request: ExecuteActivityRequest, next): + # Validate input + if isinstance(request.input, dict) and request.input.get('invalid'): + self.events.append('act_validation_failed') + raise ValueError('Invalid activity input') + + self.events.append('act_validation_passed') + result = next(request) + + # Validate output + if isinstance(result, str) and 'invalid' in 
result: + self.events.append('act_output_validation_failed') + raise ValueError('Invalid activity output') + + self.events.append('act_output_validation_passed') + return result + + +def test_single_interceptor_workflow_execution(monkeypatch): + """Test single interceptor around workflow execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='simple') + def simple(ctx, x: int): + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['simple'] + result = orch(_make_orch_ctx(), 5) + + # For non-generator workflows, the result is returned directly + assert result == 10 + assert events == [ + 'log:wf_start:5', + 'log:wf_complete:10', + ] + + +def test_single_interceptor_activity_execution(monkeypatch): + """Test single interceptor around activity execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='double') + def double(ctx, x: int) -> int: + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['double'] + result = act(_make_act_ctx(), 7) + + assert result == 14 + assert events == [ + 'log:act_start:7', + 'log:act_complete:14', + ] + + +def test_multiple_interceptors_execution_order(monkeypatch): + """Test multiple interceptors execute in correct order.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + outer_interceptor = _LoggingInterceptor(events, 'outer') + inner_interceptor = _LoggingInterceptor(events, 'inner') + + # First interceptor in list is outermost + rt = WorkflowRuntime(runtime_interceptors=[outer_interceptor, inner_interceptor]) + + @rt.workflow(name='ordered') + def ordered(ctx, x: int): + return x + 1 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['ordered'] + result = orch(_make_orch_ctx(), 3) + + assert result == 4 + # Outer interceptor enters first, exits last (stack semantics) + assert events == [ + 'outer:wf_start:3', + 'inner:wf_start:3', + 'inner:wf_complete:4', + 'outer:wf_complete:4', + ] + + +def test_tracing_interceptor_context_restoration(monkeypatch): + """Test tracing interceptor properly handles trace context.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + tracing_interceptor = _TracingInterceptor(events) + rt = WorkflowRuntime(runtime_interceptors=[tracing_interceptor]) + + @rt.workflow(name='traced') + def traced(ctx, input_data): + # Workflow can access the trace context that was restored + return {'result': input_data.get('value', 0) * 2} + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['traced'] + + # Input with tracing data + input_with_trace = {'value': 5, 'tracing': {'trace_id': 'abc123', 'span_id': 'def456'}} + + result = orch(_make_orch_ctx(), input_with_trace) + + assert result == {'result': 10} + assert events == [ + "wf_trace_restored:{'trace_id': 'abc123', 'span_id': 'def456'}", + "wf_trace_cleanup:{'trace_id': 'abc123', 'span_id': 'def456'}", + ] + + +def 
test_validation_interceptor_input_validation(monkeypatch): + """Test validation interceptor catches invalid inputs.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + validation_interceptor = _ValidationInterceptor(events) + rt = WorkflowRuntime(runtime_interceptors=[validation_interceptor]) + + @rt.workflow(name='validated') + def validated(ctx, input_data): + return {'result': 'ok'} + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['validated'] + + # Test valid input + result = orch(_make_orch_ctx(), {'value': 5}) + + assert result == {'result': 'ok'} + assert 'wf_validation_passed' in events + assert 'wf_output_validation_passed' in events + + # Test invalid input + events.clear() + + with pytest.raises(ValueError, match='Invalid workflow input'): + orch(_make_orch_ctx(), {'invalid': True}) + + assert 'wf_validation_failed' in events + + +def test_interceptor_error_handling_workflow(monkeypatch): + """Test interceptor properly handles workflow errors.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='error_wf') + def error_wf(ctx, x: int): + raise ValueError('workflow error') + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['error_wf'] + + with pytest.raises(ValueError, match='workflow error'): + orch(_make_orch_ctx(), 1) + + assert events == [ + 'log:wf_start:1', + 'log:wf_error:ValueError', + ] + + +def test_interceptor_error_handling_activity(monkeypatch): + """Test interceptor properly handles activity errors.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='error_act') + def error_act(ctx, x: int) -> int: + raise RuntimeError('activity error') + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['error_act'] + + with pytest.raises(RuntimeError, match='activity error'): + act(_make_act_ctx(), 5) + + assert events == [ + 'log:act_start:5', + 'log:act_error:RuntimeError', + ] + + +def test_async_workflow_with_interceptors(monkeypatch): + """Test interceptors work with async workflows.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='async_wf') + async def async_wf(ctx, x: int): + # Simple async workflow + return x * 3 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['async_wf'] + gen_result = orch(_make_orch_ctx(), 4) + + # Async workflows return a generator that needs to be driven + with pytest.raises(StopIteration) as stop: + next(gen_result) + result = stop.value.value + + assert result == 12 + # The interceptor sees the generator being returned, not the final result + assert events[0] == 'log:wf_start:4' + assert 'log:wf_complete:' in events[1] # The generator object is logged + + +def test_async_activity_with_interceptors(monkeypatch): + """Test interceptors work with async 
activities.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='async_act') + async def async_act(ctx, x: int) -> int: + await asyncio.sleep(0) # Simulate async work + return x * 4 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['async_act'] + # Async wrapper returns a coroutine that must be awaited + result = asyncio.run(act(_make_act_ctx(), 3)) + + assert result == 12 + assert events == [ + 'log:act_start:3', + 'log:act_complete:12', + ] + + +def test_generator_workflow_with_interceptors(monkeypatch): + """Test interceptors work with generator workflows.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='gen_wf') + def gen_wf(ctx, x: int): + v1 = yield 'step1' + v2 = yield 'step2' + return (x, v1, v2) + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['gen_wf'] + gen_orch = orch(_make_orch_ctx(), 1) + + # Drive the generator + assert next(gen_orch) == 'step1' + assert gen_orch.send('result1') == 'step2' + with pytest.raises(StopIteration) as stop: + gen_orch.send('result2') + result = stop.value.value + + assert result == (1, 'result1', 'result2') + # For generator workflows, interceptor sees the generator being returned + assert events[0] == 'log:wf_start:1' + assert 'log:wf_complete:' in events[1] # The generator object is logged + + +def test_interceptor_chain_with_early_return(monkeypatch): + """Test interceptor can modify or short-circuit execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _ShortCircuitInterceptor(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + events.append('short_circuit_check') + if isinstance(request.input, dict) and request.input.get('short_circuit'): + events.append('short_circuited') + return 'short_circuit_result' + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): + return next(request) + + logging_interceptor = _LoggingInterceptor(events, 'log') + short_circuit_interceptor = _ShortCircuitInterceptor() + + rt = WorkflowRuntime(runtime_interceptors=[short_circuit_interceptor, logging_interceptor]) + + @rt.workflow(name='maybe_short') + def maybe_short(ctx, input_data): + return 'normal_result' + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['maybe_short'] + + # Test normal execution + result = orch(_make_orch_ctx(), {'value': 5}) + + assert result == 'normal_result' + assert 'short_circuit_check' in events + assert 'log:wf_start' in str(events) + assert 'log:wf_complete' in str(events) + + # Test short-circuit execution + events.clear() + result = orch(_make_orch_ctx(), {'short_circuit': True}) + + assert result == 'short_circuit_result' + assert 'short_circuit_check' in events + assert 'short_circuited' in events + # Logging interceptor should not be called when short-circuited + assert 'log:wf_start' not in str(events) + + +def test_interceptor_input_transformation(monkeypatch): + """Test interceptor can transform inputs 
before execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _TransformInterceptor(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Transform input by adding metadata + if isinstance(request.input, dict): + transformed_input = {**request.input, 'interceptor_metadata': 'added'} + new_input = ExecuteWorkflowRequest(ctx=request.ctx, input=transformed_input) + events.append(f'transformed_input:{transformed_input}') + return next(new_input) + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): + return next(request) + + transform_interceptor = _TransformInterceptor() + rt = WorkflowRuntime(runtime_interceptors=[transform_interceptor]) + + @rt.workflow(name='transform_test') + def transform_test(ctx, input_data): + # Workflow should see the transformed input + return input_data + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['transform_test'] + result = orch(_make_orch_ctx(), {'original': 'value'}) + + # Result should include the interceptor metadata + assert result == {'original': 'value', 'interceptor_metadata': 'added'} + assert 'transformed_input:' in str(events) + + +def test_runtime_interceptor_can_shape_activity_result(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _ShapeResult(RuntimeInterceptor): + def execute_activity(self, request, next): # type: ignore[override] + res = next(request) + return {'wrapped': res} + + rt = WorkflowRuntime(runtime_interceptors=[_ShapeResult()]) + + @rt.activity(name='echo') + def echo(_ctx, x): + return x + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['echo'] + out = act(_make_act_ctx(), 7) + assert out == {'wrapped': 7} diff --git a/ext/dapr-ext-workflow/tests/test_interceptors.py b/ext/dapr-ext-workflow/tests/test_interceptors.py new file mode 100644 index 000000000..9ba37287d --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_interceptors.py @@ -0,0 +1,176 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from dapr.ext.workflow import RuntimeInterceptor, WorkflowRuntime + +from ._fakes import make_act_ctx as _make_act_ctx +from ._fakes import make_orch_ctx as _make_orch_ctx + +""" +Comprehensive inbound interceptor tests for Dapr WorkflowRuntime. + +Tests the current interceptor system for runtime-side workflow and activity execution. +""" + + +""" +Runtime interceptor chain tests for `WorkflowRuntime`. + +This suite intentionally uses a fake worker/registry to validate interceptor composition +without requiring a sidecar. It focuses on the "why" behind runtime interceptors: + +- Ensure `execute_workflow` and `execute_activity` hooks compose in order and are + invoked exactly once around workflow entry/activity execution. 
+- Cover both generator-based and async workflows, asserting the chain returns a + generator to the runtime (rather than iterating it), preserving send()/throw() + semantics during orchestration replay. +- Keep signal-to-noise high for failures in chain logic independent of gRPC/sidecar. + +These tests complement outbound/client interceptor tests and e2e tests by providing +fast, deterministic coverage of the chaining behavior and generator handling rules. +""" + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _RecorderInterceptor(RuntimeInterceptor): + def __init__(self, events: list[str], label: str): + self.events = events + self.label = label + + def execute_workflow(self, request, next): # type: ignore[override] + self.events.append(f'{self.label}:wf_enter:{request.input!r}') + ret = next(request) + self.events.append(f'{self.label}:wf_ret_type:{ret.__class__.__name__}') + return ret + + def execute_activity(self, request, next): # type: ignore[override] + self.events.append(f'{self.label}:act_enter:{request.input!r}') + res = next(request) + self.events.append(f'{self.label}:act_exit:{res!r}') + return res + + +def test_generator_workflow_hooks_sequence(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + ic = _RecorderInterceptor(events, 'mw') + rt = WorkflowRuntime(runtime_interceptors=[ic]) + + @rt.workflow(name='gen') + def gen(ctx, x: int): + v = yield 'A' + v2 = yield 'B' + return (x, v, v2) + + # Drive the registered orchestrator + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['gen'] + gen_driver = orch(_make_orch_ctx(), 10) + # Prime and run + assert next(gen_driver) == 'A' + assert gen_driver.send('ra') == 'B' + with pytest.raises(StopIteration) as stop: + gen_driver.send('rb') + result = stop.value.value + + assert result == (10, 'ra', 'rb') + # Interceptors run once around the workflow entry; they return a generator to the runtime + assert events[0] == 'mw:wf_enter:10' + assert events[1].startswith('mw:wf_ret_type:') + + +def test_async_workflow_hooks_called(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + ic = _RecorderInterceptor(events, 'mw') + rt = WorkflowRuntime(runtime_interceptors=[ic]) + + @rt.workflow(name='awf') + async def awf(ctx, x: int): + # No awaits to keep the driver simple + return x + 1 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['awf'] + gen_orch = orch(_make_orch_ctx(), 41) + with pytest.raises(StopIteration) as stop: + next(gen_orch) + result = stop.value.value + + assert result == 42 + # For async workflow, interceptor sees entry and a generator return type + assert events[0] == 'mw:wf_enter:41' + assert events[1].startswith('mw:wf_ret_type:') + + +def test_activity_hooks_and_policy(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _ExplodingActivity(RuntimeInterceptor): + def 
execute_activity(self, request, next): # type: ignore[override] + raise RuntimeError('boom') + + def execute_workflow(self, request, next): # type: ignore[override] + return next(request) + + # Continue-on-error policy + rt = WorkflowRuntime( + runtime_interceptors=[_RecorderInterceptor(events, 'mw'), _ExplodingActivity()] + ) + + @rt.activity(name='double') + def double(ctx, x: int) -> int: + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['double'] + # Error in interceptor bubbles up + with pytest.raises(RuntimeError): + act(_make_act_ctx(), 5) diff --git a/ext/dapr-ext-workflow/tests/test_metadata_context.py b/ext/dapr-ext-workflow/tests/test_metadata_context.py new file mode 100644 index 000000000..1745b6be3 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_metadata_context.py @@ -0,0 +1,372 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Optional + +import pytest +from dapr.ext.workflow import ( + ClientInterceptor, + DaprWorkflowClient, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + ScheduleWorkflowRequest, + WorkflowOutboundInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'id' + self.current_utc_datetime = datetime(2024, 1, 1) + self._custom_status = None + self.is_replaying = False + self.workflow_name = 'wf' + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + def call_activity(self, activity, *, input=None, retry_policy=None, app_id=None): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def call_sub_orchestrator( + self, wf, *, input=None, instance_id=None, retry_policy=None, app_id=None + ): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + def create_timer(self, fire_at): + class _T: + def __init__(self, v): + self._v = v + + return _T(fire_at) + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self, v): + self._v = v + + return _T(name) + + +def _drive(gen, returned): + try: + t = gen.send(None) + assert hasattr(t, '_v') + res = returned + while True: + t = gen.send(res) + assert hasattr(t, '_v') + except StopIteration as stop: + return stop.value + + +def test_client_schedule_metadata_envelope(monkeypatch): + import durabletask.client as client_mod + + captured: dict[str, Any] = {} + + class _FakeClient: + def 
__init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, + name, + *, + input=None, + instance_id=None, + start_at: Optional[datetime] = None, + reuse_id_policy=None, + ): # noqa: E501 + captured['name'] = name + captured['input'] = input + captured['instance_id'] = instance_id + captured['start_at'] = start_at + captured['reuse_id_policy'] = reuse_id_policy + return 'id-1' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _InjectMetadata(ClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + # Add metadata without touching args + md = {'otel.trace_id': 't-123'} + new_request = ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + return next(new_request) + + client = DaprWorkflowClient(interceptors=[_InjectMetadata()]) + + def wf(ctx, x): + yield 'noop' + + wf.__name__ = 'meta_wf' + instance_id = client.schedule_new_workflow(wf, input={'a': 1}) + assert instance_id == 'id-1' + env = captured['input'] + assert isinstance(env, dict) + assert '__dapr_meta__' in env and '__dapr_payload__' in env + assert env['__dapr_payload__'] == {'a': 1} + assert env['__dapr_meta__']['metadata']['otel.trace_id'] == 't-123' + + +def test_runtime_inbound_unwrap_and_metadata_visible(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + seen: dict[str, Any] = {} + + class _Recorder(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + seen['metadata'] = request.metadata + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + seen['act_metadata'] = request.metadata + return next(request) + + rt = WorkflowRuntime(runtime_interceptors=[_Recorder()]) + + @rt.workflow(name='unwrap') + def unwrap(ctx, x): + # x should be the original payload, not the envelope + assert x == {'hello': 'world'} + return 'ok' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['unwrap'] + envelope = { + '__dapr_meta__': {'v': 1, 'metadata': {'c': 'd'}}, + '__dapr_payload__': {'hello': 'world'}, + } + result = orch(_FakeOrchCtx(), envelope) + assert result == 'ok' + assert seen['metadata'] == {'c': 'd'} + + +def test_outbound_activity_and_child_wrap_metadata(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _AddActMeta(WorkflowOutboundInterceptor): + def call_activity(self, request, next): # type: ignore[override] + # Wrap returned args with metadata by returning a new CallActivityRequest + return next( + type(request)( + activity_name=request.activity_name, + input=request.input, + retry_policy=request.retry_policy, + app_id=request.app_id, + workflow_ctx=request.workflow_ctx, + metadata={'k': 'v'}, + ) + ) + + def call_child_workflow(self, request, next): # type: ignore[override] + return next( + type(request)( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + retry_policy=request.retry_policy, + app_id=request.app_id, + workflow_ctx=request.workflow_ctx, + metadata={'p': 'q'}, + ) + ) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_AddActMeta()]) + + @rt.workflow(name='parent') + def parent(ctx, x): + a 
= yield ctx.call_activity(lambda: None, input={'i': 1}) + b = yield ctx.call_child_workflow(lambda c, y: None, input={'j': 2}) + # Return both so we can assert envelopes surfaced through our fake driver + return a, b + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['parent'] + gen = orch(_FakeOrchCtx(), 0) + # First yield: activity token received by driver; shape may be envelope or raw depending on adapter + t1 = gen.send(None) + assert hasattr(t1, '_v') + # Resume with any value; our fake driver ignores and loops + t2 = gen.send({'act': 'done'}) + assert hasattr(t2, '_v') + with pytest.raises(StopIteration) as stop: + gen.send({'child': 'done'}) + result = stop.value.value + # The result is whatever user returned; envelopes validated above + assert isinstance(result, tuple) and len(result) == 2 + + +def test_context_set_metadata_default_propagation(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + # No outbound interceptor needed; runtime will wrap using ctx.get_metadata() + rt = WorkflowRuntime() + + @rt.workflow(name='use_ctx_md') + def use_ctx_md(ctx, x): + # Set default metadata on context + ctx.set_metadata({'k': 'ctx'}) + env = yield ctx.call_activity(lambda: None, input={'p': 1}) + # Return the raw yielded value for assertion + return env + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['use_ctx_md'] + gen = orch(_FakeOrchCtx(), 0) + yielded = gen.send(None) + assert hasattr(yielded, '_v') + env = yielded._v + assert isinstance(env, dict) + assert env.get('__dapr_meta__', {}).get('metadata', {}).get('k') == 'ctx' + + +def test_per_call_metadata_overrides_context(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='override_ctx_md') + def override_ctx_md(ctx, x): + ctx.set_metadata({'k': 'ctx'}) + env = yield ctx.call_activity(lambda: None, input={'p': 1}, metadata={'k': 'per'}) + return env + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['override_ctx_md'] + gen = orch(_FakeOrchCtx(), 0) + yielded = gen.send(None) + env = yielded._v + assert isinstance(env, dict) + assert env.get('__dapr_meta__', {}).get('metadata', {}).get('k') == 'per' + + +def test_execution_info_workflow_and_activity(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + def act(ctx, x): + # activity inbound metadata and execution info available + md = ctx.get_metadata() + ei = ctx.execution_info + assert md == {'m': 'v'} + assert ei is not None and ei.inbound_metadata == {'m': 'v'} + # activity_name should reflect the registered name + assert ei.activity_name == 'act' + return x + + @rt.workflow(name='execinfo') + def execinfo(ctx, x): + # set default metadata + ctx.set_metadata({'m': 'v'}) + # workflow execution info available (minimal inbound only) + wi = ctx.execution_info + assert wi is not None and wi.inbound_metadata == {} + v = yield ctx.call_activity(act, input=42) + return v + + # register activity + rt.activity(name='act')(act) + orch = rt._WorkflowRuntime__worker._registry.orchestrators['execinfo'] + gen = orch(_FakeOrchCtx(), 7) + # drive one yield (call_activity) + gen.send(None) + # send back a value for activity result + with pytest.raises(StopIteration) as stop: + gen.send(42) + assert stop.value.value == 42 + + +def 
test_client_interceptor_can_shape_schedule_response(monkeypatch): + import durabletask.client as client_mod + + captured: dict[str, Any] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, name, *, input=None, instance_id=None, start_at=None, reuse_id_policy=None + ): + captured['name'] = name + return 'raw-id-123' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _ShapeId(ClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + rid = next(request) + return f'shaped:{rid}' + + client = DaprWorkflowClient(interceptors=[_ShapeId()]) + + def wf(ctx): + yield 'noop' + + wf.__name__ = 'shape_test' + iid = client.schedule_new_workflow(wf, input=None) + assert iid == 'shaped:raw-id-123' diff --git a/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py b/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py new file mode 100644 index 000000000..147f6d549 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py @@ -0,0 +1,298 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import durabletask.worker as worker_mod +from dapr.ext.workflow import ( + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + WorkflowOutboundInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'id' + self.current_utc_datetime = __import__('datetime').datetime(2024, 1, 1) + self.is_replaying = False + self._custom_status = None + self.workflow_name = 'wf' + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + self._continued_payload = None + self.workflow_attempt = None + + def call_activity(self, activity, *, input=None, retry_policy=None, app_id=None): + # return input back for assertion through driver + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def call_sub_orchestrator( + self, wf, *, input=None, instance_id=None, retry_policy=None, app_id=None + ): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + def create_timer(self, fire_at): + class _T: + def __init__(self, v): + self._v = v + + return _T(fire_at) + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self, v): + self._v = v + + return _T(name) + + def continue_as_new(self, new_request, *, save_events: bool = False): + # Record payload for assertions + self._continued_payload = new_request + + +def drive(gen, returned): + try: + t = 
gen.send(None) + assert hasattr(t, '_v') + res = returned + while True: + t = gen.send(res) + assert hasattr(t, '_v') + except StopIteration as stop: + return stop.value + + +class _InjectTrace(WorkflowOutboundInterceptor): + def call_activity(self, request, next): # type: ignore[override] + x = request.input + if x is None: + request = type(request)( + activity_name=request.activity_name, + input={'tracing': 'T'}, + retry_policy=request.retry_policy, + app_id=request.app_id, + ) + elif isinstance(x, dict): + out = dict(x) + out.setdefault('tracing', 'T') + request = type(request)( + activity_name=request.activity_name, + input=out, + retry_policy=request.retry_policy, + app_id=request.app_id, + ) + return next(request) + + def call_child_workflow(self, request, next): # type: ignore[override] + return next( + type(request)( + workflow_name=request.workflow_name, + input={'child': request.input}, + instance_id=request.instance_id, + retry_policy=request.retry_policy, + app_id=request.app_id, + ) + ) + + +def test_outbound_activity_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectTrace()]) + + @rt.workflow(name='w') + def w(ctx, x): + # schedule an activity; runtime should pass transformed input to durable task + y = yield ctx.call_activity(lambda: None, input={'a': 1}) + return y['tracing'] + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w'] + gen = orch(_FakeOrchCtx(), 0) + out = drive(gen, returned={'tracing': 'T', 'a': 1}) + assert out == 'T' + + +def test_outbound_child_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectTrace()]) + + def child(ctx, x): + yield 'noop' + + @rt.workflow(name='parent') + def parent(ctx, x): + y = yield ctx.call_child_workflow(child, input={'b': 2}) + return y + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['parent'] + gen = orch(_FakeOrchCtx(), 0) + out = drive(gen, returned={'child': {'b': 2}}) + assert out == {'child': {'b': 2}} + + +def test_outbound_continue_as_new_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _InjectCAN(BaseWorkflowOutboundInterceptor): + def continue_as_new(self, request, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault('x', '1') + request.metadata = md + return next(request) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectCAN()]) + + @rt.workflow(name='w2') + def w2(ctx, x): + ctx.continue_as_new({'p': 1}, carryover_metadata=True) + return 'unreached' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w2'] + fake = _FakeOrchCtx() + _ = orch(fake, 0) + # Verify envelope contains injected metadata + assert isinstance(fake._continued_payload, dict) + meta = fake._continued_payload.get('__dapr_meta__') + payload = fake._continued_payload.get('__dapr_payload__') + assert isinstance(meta, dict) and isinstance(payload, dict) + assert meta.get('metadata', {}).get('x') == '1' + assert payload == {'p': 1} + + +def test_interceptor_called_for_string_activity_names(monkeypatch): + """Test that outbound interceptors are invoked for string-based activity names. 
+ + Regression test: Previously, interceptors were only called for function activities, + not for string activity names (cross-app scenario). + """ + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + interceptor_calls = [] + + class TrackingInterceptor(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): + # Track that interceptor was called with these parameters + interceptor_calls.append( + { + 'activity_name': request.activity_name, + 'app_id': request.app_id, + 'retry_policy': request.retry_policy, + } + ) + return next(request) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[TrackingInterceptor()]) + + @rt.workflow(name='w_string_activity') + def w_string_activity(ctx, x): + # Call activity with STRING name and app_id (cross-app scenario) + # The activity doesn't need to exist for this test - we're just testing + # that the interceptor gets called + yield ctx.call_activity('remote_activity', input=x, app_id='remote-app') + return 'done' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w_string_activity'] + gen = orch(_FakeOrchCtx(), {'data': 1}) + drive(gen, 'mock-result') + + # Verify interceptor was called for string-based activity + assert len(interceptor_calls) == 1 + assert interceptor_calls[0]['activity_name'] == 'remote_activity' + assert interceptor_calls[0]['app_id'] == 'remote-app' + + +def test_interceptor_called_for_string_workflow_names(monkeypatch): + """Test that outbound interceptors are invoked for string-based workflow names. + + Regression test: Previously, interceptors were only called for function workflows, + not for string workflow names (cross-app scenario). + """ + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + interceptor_calls = [] + + class TrackingInterceptor(BaseWorkflowOutboundInterceptor): + def call_child_workflow(self, request: CallChildWorkflowRequest, next): + # Track that interceptor was called with these parameters + interceptor_calls.append( + { + 'workflow_name': request.workflow_name, + 'app_id': request.app_id, + 'retry_policy': request.retry_policy, + 'instance_id': request.instance_id, + } + ) + return next(request) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[TrackingInterceptor()]) + + @rt.workflow(name='w_string_workflow') + def w_string_workflow(ctx, x): + # Call child workflow with STRING name and app_id (cross-app scenario) + # The workflow doesn't need to exist for this test - we're just testing + # that the interceptor gets called + yield ctx.call_child_workflow( + 'remote_workflow', input=x, instance_id='test-id', app_id='remote-workflow-app' + ) + return 'done' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w_string_workflow'] + gen = orch(_FakeOrchCtx(), {'data': 1}) + drive(gen, 'mock-result') + + # Verify interceptor was called for string-based workflow + assert len(interceptor_calls) == 1 + assert interceptor_calls[0]['workflow_name'] == 'remote_workflow' + assert interceptor_calls[0]['app_id'] == 'remote-workflow-app' + assert interceptor_calls[0]['instance_id'] == 'test-id' diff --git a/ext/dapr-ext-workflow/tests/test_trace_fields.py b/ext/dapr-ext-workflow/tests/test_trace_fields.py new file mode 100644 index 000000000..03d38e1e3 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_trace_fields.py @@ -0,0 +1,60 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'wf-123' + self.current_utc_datetime = datetime(2025, 1, 1, tzinfo=timezone.utc) + self.is_replaying = False + self.workflow_name = 'wf_name' + self.parent_instance_id = 'parent-1' + self.history_event_sequence = 42 + self.trace_parent = '00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01' + self.trace_state = 'vendor=state' + self.orchestration_span_id = 'bbbbbbbbbbbbbbbb' + + +class _FakeActivityCtx: + def __init__(self): + self.orchestration_id = 'wf-123' + self.task_id = 7 + self.trace_parent = '00-cccccccccccccccccccccccccccccccc-dddddddddddddddd-01' + self.trace_state = 'v=1' + + +def test_workflow_execution_info_minimal(): + ei = WorkflowExecutionInfo(inbound_metadata={'k': 'v'}) + assert ei.inbound_metadata == {'k': 'v'} + + +def test_activity_execution_info_minimal(): + aei = ActivityExecutionInfo(inbound_metadata={'m': 'v'}, activity_name='act_name') + assert aei.inbound_metadata == {'m': 'v'} + + +def test_workflow_activity_context_execution_info_trace_fields(): + base = _FakeActivityCtx() + actx = WorkflowActivityContext(base) + aei = ActivityExecutionInfo(inbound_metadata={}, activity_name='act_name') + actx._set_execution_info(aei) + got = actx.execution_info + assert got is not None + assert got.inbound_metadata == {} diff --git a/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py b/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py new file mode 100644 index 000000000..35f933611 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py @@ -0,0 +1,171 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from __future__ import annotations + +import uuid +from datetime import datetime +from typing import Any + +from dapr.ext.workflow import ( + ClientInterceptor, + DaprWorkflowClient, + RuntimeInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchestrationContext: + def __init__(self, *, is_replaying: bool = False): + self.instance_id = 'wf-1' + self.current_utc_datetime = datetime(2025, 1, 1) + self.is_replaying = is_replaying + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + +def _drive_generator(gen, returned_value): + # Prime to first yield; then drive + next(gen) + while True: + try: + gen.send(returned_value) + except StopIteration as stop: + return stop.value + + +def test_client_injects_tracing_on_schedule(monkeypatch): + import durabletask.client as client_mod + + # monkeypatch TaskHubGrpcClient to capture inputs + scheduled: dict[str, Any] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, name, *, input=None, instance_id=None, start_at=None, reuse_id_policy=None + ): + scheduled['name'] = name + scheduled['input'] = input + scheduled['instance_id'] = instance_id + scheduled['start_at'] = start_at + scheduled['reuse_id_policy'] = reuse_id_policy + return 'id-1' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _TracingClient(ClientInterceptor): + def schedule_new_workflow(self, request, next): # type: ignore[override] + tr = {'trace_id': uuid.uuid4().hex} + if isinstance(request.input, dict) and 'tracing' not in request.input: + request = type(request)( + workflow_name=request.workflow_name, + input={**request.input, 'tracing': tr}, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + ) + return next(request) + + client = DaprWorkflowClient(interceptors=[_TracingClient()]) + + # We only need a callable with a __name__ for scheduling + def wf(ctx): + yield 'noop' + + wf.__name__ = 'inject_test' + instance_id = client.schedule_new_workflow(wf, input={'a': 1}) + assert instance_id == 'id-1' + assert scheduled['name'] == 'inject_test' + assert isinstance(scheduled['input'], dict) + assert 'tracing' in scheduled['input'] + assert scheduled['input']['a'] == 1 + + +def test_runtime_restores_tracing_before_user_code(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + seen: dict[str, Any] = {} + + class _TracingRuntime(RuntimeInterceptor): + def execute_workflow(self, request, next): # type: ignore[override] + # no-op; real restoration is app concern; test just ensures input contains tracing + return next(request) + + def execute_activity(self, request, next): # type: ignore[override] + return next(request) + + class _TracingClient2(ClientInterceptor): + def schedule_new_workflow(self, request, next): # type: ignore[override] + tr = {'trace_id': 't1'} + if isinstance(request.input, dict): + 
+                request = type(request)(
+                    workflow_name=request.workflow_name,
+                    input={**request.input, 'tracing': tr},
+                    instance_id=request.instance_id,
+                    start_at=request.start_at,
+                    reuse_id_policy=request.reuse_id_policy,
+                )
+            return next(request)
+
+    rt = WorkflowRuntime(
+        runtime_interceptors=[_TracingRuntime()],
+    )
+
+    @rt.workflow(name='w')
+    def w(ctx, x):
+        # The tracing metadata should already be present in the input.
+        assert isinstance(x, dict)
+        assert 'tracing' in x
+        seen['trace'] = x['tracing']
+        yield 'noop'
+        return 'ok'
+
+    orch = rt._WorkflowRuntime__worker._registry.orchestrators['w']
+    # In a real run, tracing is injected into the orchestrator input when the instance is
+    # scheduled (by the client or outbound interceptors); here we pass it in directly.
+    gen = orch(_FakeOrchestrationContext(), {'hello': 'world', 'tracing': {'trace_id': 't1'}})
+    out = _drive_generator(gen, returned_value='noop')
+    assert out == 'ok'
+    assert seen['trace']['trace_id'] == 't1'
diff --git a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py
index bf18cd689..f367a43fd 100644
--- a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py
+++ b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py
@@ -37,7 +37,10 @@ class WorkflowRuntimeTest(unittest.TestCase):
     def setUp(self):
         listActivities.clear()
         listOrchestrators.clear()
-        mock.patch('durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()).start()
+        self.patcher = mock.patch(
+            'durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()
+        )
+        self.patcher.start()
         self.runtime_options = WorkflowRuntime()
         if hasattr(self.mock_client_wf, '_dapr_alternate_name'):
             del self.mock_client_wf.__dict__['_dapr_alternate_name']
@@ -48,6 +51,11 @@ def setUp(self):
         if hasattr(self.mock_client_activity, '_activity_registered'):
             del self.mock_client_activity.__dict__['_activity_registered']

+    def tearDown(self):
+        """Stop the mock patch to prevent interference with other tests."""
+        self.patcher.stop()
+        mock.patch.stopall()  # Ensure all patches are stopped
+
     def mock_client_wf(ctx: DaprWorkflowContext, input):
         print(f'{input}')

diff --git a/ext/dapr-ext-workflow/tests/test_workflow_util.py b/ext/dapr-ext-workflow/tests/test_workflow_util.py
index 28e92e6c5..c1b980eda 100644
--- a/ext/dapr-ext-workflow/tests/test_workflow_util.py
+++ b/ext/dapr-ext-workflow/tests/test_workflow_util.py
@@ -1,3 +1,16 @@
+"""
+Copyright 2025 The Dapr Authors
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + import unittest from unittest.mock import patch @@ -7,6 +20,7 @@ class DaprWorkflowUtilTest(unittest.TestCase): + @patch.object(settings, 'DAPR_GRPC_ENDPOINT', '') def test_get_address_default(self): expected = f'{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}' self.assertEqual(expected, getAddress()) diff --git a/pyproject.toml b/pyproject.toml index 0378a8c8f..ed9fb11a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,8 +2,13 @@ target-version = "py310" line-length = 100 fix = true -extend-exclude = [".github", "dapr/proto"] - +extend-exclude = [ + ".github", + "dapr/proto", + "*_pb2.py", + "*_pb2_grpc.py", + "examples/**/.venv", +] [tool.ruff.lint] select = [ "I", # isort diff --git a/tox.ini b/tox.ini index 7c31dd8a3..0df064b3b 100644 --- a/tox.ini +++ b/tox.ini @@ -2,29 +2,46 @@ skipsdist = True minversion = 3.10.0 envlist = - py{310,311,312,313} + py{310,311,312,313,314} ruff, mypy, +# TODO: switch runner to uv (tox-uv plugin) +runner = virtualenv [testenv] setenv = PYTHONDONTWRITEBYTECODE=1 deps = -rdev-requirements.txt +package = editable commands = coverage run -m unittest discover -v ./tests - coverage run -a -m unittest discover -v ./ext/dapr-ext-workflow/tests - coverage run -a -m unittest discover -v ./ext/dapr-ext-grpc/tests - coverage run -a -m unittest discover -v ./ext/dapr-ext-fastapi/tests - coverage run -a -m unittest discover -v ./ext/dapr-ext-langgraph/tests - coverage run -a -m unittest discover -v ./ext/flask_dapr/tests + # ext/dapr-ext-workflow uses pytest-based tests + # only workflow has e2e tests (with daprd sidecars initiated as fixtures during tests) + !e2e: coverage run -a -m pytest -q -m "not e2e" ext/dapr-ext-workflow/tests + e2e: coverage run -a -m pytest -q -m "e2e" ext/dapr-ext-workflow/tests + + !e2e: coverage run -a -m unittest discover -v ./ext/dapr-ext-grpc/tests + !e2e: coverage run -a -m unittest discover -v ./ext/dapr-ext-fastapi/tests + !e2e: coverage run -a -m unittest discover -v ./ext/dapr-ext-langgraph/tests + !e2e: coverage run -a -m unittest discover -v ./ext/flask_dapr/tests coverage xml commands_pre = - pip3 install -e {toxinidir}/ - pip3 install -e {toxinidir}/ext/dapr-ext-workflow/ - pip3 install -e {toxinidir}/ext/dapr-ext-grpc/ - pip3 install -e {toxinidir}/ext/dapr-ext-fastapi/ - pip3 install -e {toxinidir}/ext/dapr-ext-langgraph/ - pip3 install -e {toxinidir}/ext/flask_dapr/ + # TODO: remove this before merging (after durable task is merged) + {envpython} -m pip install -e {toxinidir}/../durabletask-python/ + + {envpython} -m pip install -e {toxinidir}/ + {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-workflow/ + {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-grpc/ + {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-fastapi/ + {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-langgraph/ + {envpython} -m pip install -e {toxinidir}/ext/flask_dapr/ +# allow for overriding sidecar ports +pass_env = DAPR_GRPC_ENDPOINT,DAPR_HTTP_ENDPOINT,DAPR_RUNTIME_HOST,DAPR_GRPC_PORT,DAPR_HTTP_PORT,DURABLETASK_GRPC_ENDPOINT + +[flake8] +extend-exclude = .tox,venv,build,dist,dapr/proto,examples/**/.venv +ignore = E203,E501,W503,E701,E704,F821 +max-line-length = 100 [testenv:ruff] basepython = python3 @@ -63,6 +80,9 @@ commands = ./validate.sh jobs ./validate.sh ../ commands_pre = + # TODO: remove this before merging (after durable task is merged) + pip3 install -e {toxinidir}/../durabletask-python/ + pip3 install -e {toxinidir}/ pip3 install -e {toxinidir}/ext/dapr-ext-workflow/ pip3 install -e 
{toxinidir}/ext/dapr-ext-grpc/ @@ -97,6 +117,9 @@ deps = -rdev-requirements.txt commands = mypy --config-file mypy.ini commands_pre = + # TODO: remove this before merging (after durable task is merged) + pip3 install -e {toxinidir}/../durabletask-python/ + pip3 install -e {toxinidir}/ pip3 install -e {toxinidir}/ext/dapr-ext-workflow/ pip3 install -e {toxinidir}/ext/dapr-ext-grpc/