diff --git a/README.md b/README.md
index f6a0284..40a4e6e 100644
--- a/README.md
+++ b/README.md
@@ -126,10 +126,97 @@ Orchestrations can be continued as new using the `continue_as_new` API. This API
 Orchestrations can be suspended using the `suspend_orchestration` client API and will remain suspended until resumed using the `resume_orchestration` client API. A suspended orchestration will stop processing new events, but will continue to buffer any that happen to arrive until resumed, ensuring that no data is lost.
 An orchestration can also be terminated using the `terminate_orchestration` client API. Terminated orchestrations will stop processing new events and will discard any buffered events.
 
-### Retry policies (TODO)
+### Retry policies
 
 Orchestrations can specify retry policies for activities and sub-orchestrations. These policies control how many times and how frequently an activity or sub-orchestration will be retried in the event of a transient error.
 
+#### Creating a retry policy
+
+```python
+from datetime import timedelta
+from durabletask import task
+
+retry_policy = task.RetryPolicy(
+    first_retry_interval=timedelta(seconds=1),  # Initial delay before first retry
+    max_number_of_attempts=5,  # Maximum total attempts (includes first attempt)
+    backoff_coefficient=2.0,  # Exponential backoff multiplier (must be >= 1)
+    max_retry_interval=timedelta(seconds=30),  # Cap on retry delay
+    retry_timeout=timedelta(minutes=5),  # Total time limit for all retries (optional)
+)
+```
+
+**Notes:**
+- `max_number_of_attempts` **includes the initial attempt**. For example, `max_number_of_attempts=5` means 1 initial attempt + up to 4 retries.
+- `retry_timeout` is optional. If omitted or set to `None`, retries continue until `max_number_of_attempts` is reached.
+- `backoff_coefficient` controls exponential backoff: delay = `first_retry_interval * (backoff_coefficient ^ retry_number)`, capped by `max_retry_interval`.
+- `non_retryable_error_types` (optional) can specify additional exception types to treat as non-retryable (e.g., `[ValueError, TypeError]`). `NonRetryableError` is always non-retryable regardless of this setting.
+
+#### Using retry policies
+
+Apply retry policies to activities or sub-orchestrations:
+
+```python
+def my_orchestrator(ctx: task.OrchestrationContext, input):
+    # Retry an activity
+    result = yield ctx.call_activity(my_activity, input=data, retry_policy=retry_policy)
+
+    # Retry a sub-orchestration
+    result = yield ctx.call_sub_orchestrator(child_orchestrator, input=data, retry_policy=retry_policy)
+```
+
+#### Non-retryable errors
+
+For errors that should not be retried (e.g., validation failures, permanent errors), raise a `NonRetryableError`:
+
+```python
+from durabletask.task import NonRetryableError
+
+def my_activity(ctx: task.ActivityContext, input):
+    if input is None:
+        # This error will bypass retry logic and fail immediately
+        raise NonRetryableError("Input cannot be None")
+
+    # Transient errors (network, timeouts, etc.) will be retried
+    return call_external_service(input)
+```
+
+Even with a retry policy configured, `NonRetryableError` will fail immediately without retrying.
+
+#### Error type matching behavior
+
+**Important:** Error type matching uses **exact class name comparison**, not `isinstance()` checks. This is because exception objects are serialized to gRPC protobuf messages, where only the class name (as a string) survives serialization.
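+
+In practice the worker compares the failure's `errorType` string (typically the exception class name, i.e. `type(exc).__name__`) against the configured names. A simplified sketch of that check is shown below; the real helper added by this change is `is_error_non_retryable` in `durabletask/task.py`, and this standalone version is for illustration only:
+
+```python
+def is_non_retryable(error_type: str, non_retryable_names: set[str]) -> bool:
+    # error_type is the serialized class name, e.g. "ValueError".
+    # NonRetryableError always bypasses retries, regardless of the policy settings.
+    return error_type == "NonRetryableError" or error_type in non_retryable_names
+```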
+
+**Key implications:**
+
+- **Not inheritance-aware**: If you specify `ValueError` in `non_retryable_error_types`, it will only match exceptions with the exact class name `"ValueError"`. A custom subclass like `CustomValueError(ValueError)` will NOT match.
+- **Workaround**: List all exception types explicitly, including subclasses you want to handle.
+- **Built-in exception**: `NonRetryableError` is always treated as non-retryable, matched by the name `"NonRetryableError"`.
+
+**Example:**
+
+```python
+from datetime import timedelta
+from durabletask import task
+
+# Custom exception hierarchy
+class ValidationError(ValueError):
+    pass
+
+# This policy ONLY matches exact "ValueError" by name
+retry_policy = task.RetryPolicy(
+    first_retry_interval=timedelta(seconds=1),
+    max_number_of_attempts=3,
+    non_retryable_error_types=[ValueError]  # Won't match ValidationError subclass!
+)
+
+# To handle both, list them explicitly:
+retry_policy = task.RetryPolicy(
+    first_retry_interval=timedelta(seconds=1),
+    max_number_of_attempts=3,
+    non_retryable_error_types=[ValueError, ValidationError]  # Both converted to name strings
+)
+```
+
 ## Getting Started
 
 ### Prerequisites
@@ -194,7 +281,7 @@ Certain aspects like multi-app activities require the full dapr runtime to be ru
 
 ```shell
 dapr init || true
-dapr run --app-id test-app --dapr-grpc-port 4001 --components-path ./examples/components/
+dapr run --app-id test-app --dapr-grpc-port 4001 --resources-path ./examples/components/
 ```
 
 To run the E2E tests on a specific python version (eg: 3.11), run the following command from the project root:
diff --git a/durabletask/client.py b/durabletask/client.py
index 1e28f30..e3d391f 100644
--- a/durabletask/client.py
+++ b/durabletask/client.py
@@ -32,6 +32,7 @@ class OrchestrationStatus(Enum):
     CONTINUED_AS_NEW = pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
     PENDING = pb.ORCHESTRATION_STATUS_PENDING
     SUSPENDED = pb.ORCHESTRATION_STATUS_SUSPENDED
+    CANCELED = pb.ORCHESTRATION_STATUS_CANCELED
 
     def __str__(self):
         return helpers.get_orchestration_status_str(self.value)
@@ -127,9 +128,28 @@ def __init__(
             interceptors=interceptors,
             options=channel_options,
         )
+        self._channel = channel
         self._stub = stubs.TaskHubSidecarServiceStub(channel)
         self._logger = shared.get_logger("client", log_handler, log_formatter)
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        try:
+            self.close()
+        finally:
+            return False
+
+    def close(self) -> None:
+        """Close the underlying gRPC channel."""
+        try:
+            # grpc.Channel.close() is idempotent
+            self._channel.close()
+        except Exception:
+            # Best-effort cleanup
+            pass
+
     def schedule_new_orchestration(
         self,
         orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
diff --git a/durabletask/task.py b/durabletask/task.py
index 66abc28..0b27b6f 100644
--- a/durabletask/task.py
+++ b/durabletask/task.py
@@ -233,6 +233,29 @@ class OrchestrationStateError(Exception):
     pass
 
 
+class NonRetryableError(Exception):
+    """Exception indicating the operation should not be retried.
+
+    If an activity or sub-orchestration raises this exception, retry logic will be
+    bypassed and the failure will be returned immediately to the orchestrator.
+ """ + + pass + + +def is_error_non_retryable(error_type: str, policy: RetryPolicy) -> bool: + """Checks whether an error type is non-retryable.""" + is_non_retryable = False + if error_type == "NonRetryableError": + is_non_retryable = True + elif ( + policy.non_retryable_error_types is not None + and error_type in policy.non_retryable_error_types + ): + is_non_retryable = True + return is_non_retryable + + class Task(ABC, Generic[T]): """Abstract base class for asynchronous tasks in a durable orchestration.""" @@ -397,7 +420,7 @@ def compute_next_delay(self) -> Optional[timedelta]: next_delay_f = min( next_delay_f, self._retry_policy.max_retry_interval.total_seconds() ) - return timedelta(seconds=next_delay_f) + return timedelta(seconds=next_delay_f) return None @@ -490,6 +513,7 @@ def __init__( backoff_coefficient: Optional[float] = 1.0, max_retry_interval: Optional[timedelta] = None, retry_timeout: Optional[timedelta] = None, + non_retryable_error_types: Optional[list[Union[str, type]]] = None, ): """Creates a new RetryPolicy instance. @@ -505,6 +529,11 @@ def __init__( The maximum retry interval to use for any retry attempt. retry_timeout : Optional[timedelta] The maximum amount of time to spend retrying the operation. + non_retryable_error_types : Optional[list[Union[str, type]]] + A list of exception type names or classes that should not be retried. + If a failure's error type matches any of these, the task fails immediately. + The built-in NonRetryableError is always treated as non-retryable regardless + of this setting. """ # validate inputs if first_retry_interval < timedelta(seconds=0): @@ -523,6 +552,16 @@ def __init__( self._backoff_coefficient = backoff_coefficient self._max_retry_interval = max_retry_interval self._retry_timeout = retry_timeout + # Normalize non-retryable error type names to a set of strings + names: Optional[set[str]] = None + if non_retryable_error_types: + names = set[str]() + for t in non_retryable_error_types: + if isinstance(t, str) and t: + names.add(t) + elif isinstance(t, type): + names.add(t.__name__) + self._non_retryable_error_types = names @property def first_retry_interval(self) -> timedelta: @@ -549,6 +588,15 @@ def retry_timeout(self) -> Optional[timedelta]: """The maximum amount of time to spend retrying the operation.""" return self._retry_timeout + @property + def non_retryable_error_types(self) -> Optional[set[str]]: + """Set of error type names that should not be retried. + + Comparison is performed against the errorType string provided by the + backend (typically the exception class name). + """ + return self._non_retryable_error_types + def get_name(fn: Callable) -> str: """Returns the name of the provided function""" diff --git a/durabletask/worker.py b/durabletask/worker.py index daa661b..8fcc763 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -159,6 +159,8 @@ class TaskHubGrpcWorker: interceptors to apply to the channel. Defaults to None. concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for controlling worker concurrency limits. If None, default settings are used. + stop_timeout (float, optional): Maximum time in seconds to wait for the worker thread + to stop when calling stop(). Defaults to 30.0. Useful to set lower values in tests. Attributes: concurrency_options (ConcurrencyOptions): The current concurrency configuration. 
@@ -224,6 +226,7 @@ def __init__(
         interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
         concurrency_options: Optional[ConcurrencyOptions] = None,
         channel_options: Optional[Sequence[tuple[str, Any]]] = None,
+        stop_timeout: float = 30.0,
     ):
         self._registry = _Registry()
         self._host_address = host_address if host_address else shared.get_default_host_address()
@@ -232,6 +235,7 @@ def __init__(
         self._is_running = False
         self._secure_channel = secure_channel
         self._channel_options = channel_options
+        self._stop_timeout = stop_timeout
 
         # Use provided concurrency options or create default ones
         self._concurrency_options = (
@@ -398,7 +402,10 @@ def should_invalidate_connection(rpc_error):
 
         def stream_reader():
             try:
-                for work_item in self._response_stream:
+                stream = self._response_stream
+                if stream is None:
+                    return
+                for work_item in stream:  # type: ignore
                     work_item_queue.put(work_item)
             except Exception as e:
                 work_item_queue.put(e)
@@ -433,6 +440,8 @@ def stream_reader():
                         pass
                     else:
                         self._logger.warning(f"Unexpected work item type: {request_type}")
+            except grpc.RpcError:
+                raise  # let it be captured/parsed by outer except and avoid noisy log
            except Exception as e:
                 self._logger.warning(f"Error in work item stream: {e}")
                 raise e
@@ -489,11 +498,39 @@ def stop(self):
         if self._response_stream is not None:
             self._response_stream.cancel()
         if self._runLoop is not None:
-            self._runLoop.join(timeout=30)
+            self._runLoop.join(timeout=self._stop_timeout)
         self._async_worker_manager.shutdown()
         self._logger.info("Worker shutdown completed")
         self._is_running = False
 
+    def _handle_grpc_execution_error(self, rpc_error: grpc.RpcError, request_type: str):
+        """Handle a gRPC execution error during shutdown or another benign condition."""
+        # During shutdown, or if the instance was terminated, the channel may be closed
+        # or the instance may no longer be recognized by the sidecar. Treat these as benign
+        # to reduce noisy logging when shutting down.
+ details = str(rpc_error).lower() + benign_errors = { + grpc.StatusCode.CANCELLED, + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.UNKNOWN, + } + if ( + self._shutdown.is_set() + and rpc_error.code() in benign_errors + or ( + "unknown instance id/task id combo" in details + or "channel closed" in details + or "locally cancelled by application" in details + ) + ): + self._logger.debug( + f"Ignoring gRPC {request_type} execution error during shutdown/benign condition: {rpc_error}" + ) + else: + self._logger.exception( + f"Failed to execute gRPC {request_type} execution error: {rpc_error}" + ) + def _execute_orchestrator( self, req: pb.OrchestratorRequest, @@ -527,6 +564,8 @@ def _execute_orchestrator( try: stub.CompleteOrchestratorTask(res) + except grpc.RpcError as rpc_error: # type: ignore + self._handle_grpc_execution_error(rpc_error, "orchestrator") except Exception as ex: self._logger.exception( f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}" @@ -558,6 +597,8 @@ def _execute_activity( try: stub.CompleteActivityTask(res) + except grpc.RpcError as rpc_error: # type: ignore + self._handle_grpc_execution_error(rpc_error, "activity") except Exception as ex: self._logger.exception( f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" @@ -802,7 +843,8 @@ def call_activity_function_helper( id = self.next_sequence_number() router = pb.TaskRouter() - router.sourceAppID = self._app_id + if self._app_id is not None: + router.sourceAppID = self._app_id if app_id is not None: router.targetAppID = app_id @@ -1078,16 +1120,26 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven if isinstance(activity_task, task.RetryableTask): if activity_task._retry_policy is not None: - next_delay = activity_task.compute_next_delay() - if next_delay is None: + # Check for non-retryable errors by type name + if task.is_error_non_retryable( + event.taskFailed.failureDetails.errorType, activity_task._retry_policy + ): activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", event.taskFailed.failureDetails, ) ctx.resume() else: - activity_task.increment_attempt_count() - ctx.create_timer_internal(next_delay, activity_task) + next_delay = activity_task.compute_next_delay() + if next_delay is None: + activity_task.fail( + f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", + event.taskFailed.failureDetails, + ) + ctx.resume() + else: + activity_task.increment_attempt_count() + ctx.create_timer_internal(next_delay, activity_task) elif isinstance(activity_task, task.CompletableTask): activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", @@ -1145,16 +1197,26 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven return if isinstance(sub_orch_task, task.RetryableTask): if sub_orch_task._retry_policy is not None: - next_delay = sub_orch_task.compute_next_delay() - if next_delay is None: + # Check for non-retryable errors by type name + if task.is_error_non_retryable( + failedEvent.failureDetails.errorType, sub_orch_task._retry_policy + ): sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", failedEvent.failureDetails, ) ctx.resume() else: - sub_orch_task.increment_attempt_count() - ctx.create_timer_internal(next_delay, 
sub_orch_task) + next_delay = sub_orch_task.compute_next_delay() + if next_delay is None: + sub_orch_task.fail( + f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", + failedEvent.failureDetails, + ) + ctx.resume() + else: + sub_orch_task.increment_attempt_count() + ctx.create_timer_internal(next_delay, sub_orch_task) elif isinstance(sub_orch_task, task.CompletableTask): sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", diff --git a/requirements.txt b/requirements.txt index 7b288f0..b6902e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -# requirements in pyproject.toml +# pyproject.toml has the dependencies for this project \ No newline at end of file diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index b671cf8..c74ba17 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -1,5 +1,6 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch +from durabletask import client from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl from durabletask.internal.shared import get_default_host_address, get_grpc_channel @@ -140,3 +141,27 @@ def test_sync_channel_passes_base_options_and_max_lengths(): assert ("grpc.max_send_message_length", 1234) in opts assert ("grpc.max_receive_message_length", 5678) in opts assert ("grpc.primary_user_agent", "durabletask-tests") in opts + + +def test_taskhub_client_close_handles_exceptions(): + """Test that close() handles exceptions gracefully (edge case not easily testable in E2E).""" + with patch("durabletask.internal.shared.get_grpc_channel") as mock_get_channel: + mock_channel = MagicMock() + mock_channel.close.side_effect = Exception("close failed") + mock_get_channel.return_value = mock_channel + + task_hub_client = client.TaskHubGrpcClient() + # Should not raise exception + task_hub_client.close() + + +def test_taskhub_client_close_closes_channel_handles_exceptions(): + """Test that close() closes the channel and handles exceptions gracefully.""" + with patch("durabletask.internal.shared.get_grpc_channel") as mock_get_channel: + mock_channel = MagicMock() + mock_channel.close.side_effect = Exception("close failed") + mock_get_channel.return_value = mock_channel + + task_hub_client = client.TaskHubGrpcClient() + task_hub_client.close() + mock_channel.close.assert_called_once() diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 225456d..9debf39 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -5,6 +5,7 @@ import threading import time from datetime import timedelta +from typing import Optional import pytest @@ -16,6 +17,33 @@ pytestmark = pytest.mark.e2e +def _wait_until_terminal( + hub_client: client.TaskHubGrpcClient, + instance_id: str, + *, + timeout_s: int = 30, + fetch_payloads: bool = True, +) -> Optional[client.OrchestrationState]: + """Polling-based completion wait that does not rely on the completion stream. + + Returns the terminal state or None if timeout. 
+ """ + deadline = time.time() + timeout_s + delay = 0.1 + while time.time() < deadline: + st = hub_client.get_orchestration_state(instance_id, fetch_payloads=fetch_payloads) + if st and st.runtime_status in ( + client.OrchestrationStatus.COMPLETED, + client.OrchestrationStatus.FAILED, + client.OrchestrationStatus.TERMINATED, + client.OrchestrationStatus.CANCELED, + ): + return st + time.sleep(delay) + delay = min(delay * 1.5, 1.0) + return None + + def test_empty_orchestration(): invoked = False @@ -37,6 +65,11 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): id = c.schedule_new_orchestration(empty_orchestrator) state = c.wait_for_orchestration_completion(id, timeout=30) + # Test calling wait again on already-completed orchestration (should return immediately) + state2 = c.wait_for_orchestration_completion(id, timeout=30) + assert state2 is not None + assert state2.runtime_status == client.OrchestrationStatus.COMPLETED + assert invoked assert state is not None assert state.name == task.get_name(empty_orchestrator) @@ -61,14 +94,14 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): return numbers # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(sequence) w.add_activity(plus_one) w.start() - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(sequence, input=1) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(sequence, input=1) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(sequence) @@ -104,15 +137,15 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): return error_msg # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.add_activity(throw) w.add_activity(increment_counter) w.start() - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(orchestrator, input=1) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator, input=1) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(orchestrator) @@ -146,15 +179,15 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): yield task.when_all(tasks) # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_activity(increment) w.add_orchestrator(orchestrator_child) w.add_orchestrator(parent_orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=10) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=10) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status 
== client.OrchestrationStatus.COMPLETED @@ -170,7 +203,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): return [a, b, c] # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() @@ -199,16 +232,16 @@ def orchestrator(ctx: task.OrchestrationContext, _): return "timed out" # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() # Start the orchestration and immediately raise events to it. - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(orchestrator) - if raise_event: - task_hub_client.raise_orchestration_event(id, "Approval") - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + if raise_event: + task_hub_client.raise_orchestration_event(id, "Approval") + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -224,37 +257,37 @@ def orchestrator(ctx: task.OrchestrationContext, _): return result # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(orchestrator) - state = task_hub_client.wait_for_orchestration_start(id, timeout=30) - assert state is not None - - # Suspend the orchestration and wait for it to go into the SUSPENDED state - task_hub_client.suspend_orchestration(id) - while state.runtime_status == client.OrchestrationStatus.RUNNING: - time.sleep(0.1) - state = task_hub_client.get_orchestration_state(id) + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + state = task_hub_client.wait_for_orchestration_start(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.SUSPENDED - - # Raise an event to the orchestration and confirm that it does NOT complete - task_hub_client.raise_orchestration_event(id, "my_event", data=42) - try: - state = task_hub_client.wait_for_orchestration_completion(id, timeout=3) - assert False, "Orchestration should not have completed" - except TimeoutError: - pass - # Resume the orchestration and wait for it to complete - task_hub_client.resume_orchestration(id) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED - assert state.serialized_output == json.dumps(42) + # Suspend the orchestration and wait for it to go into the SUSPENDED state + task_hub_client.suspend_orchestration(id) + while state.runtime_status == client.OrchestrationStatus.RUNNING: + time.sleep(0.1) + state = task_hub_client.get_orchestration_state(id) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.SUSPENDED + + # Raise an event to the orchestration and confirm that it does NOT complete + task_hub_client.raise_orchestration_event(id, "my_event", data=42) + 
try: + state = task_hub_client.wait_for_orchestration_completion(id, timeout=3) + assert False, "Orchestration should not have completed" + except TimeoutError: + pass + + # Resume the orchestration and wait for it to complete + task_hub_client.resume_orchestration(id) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(42) def test_terminate(): @@ -263,27 +296,29 @@ def orchestrator(ctx: task.OrchestrationContext, _): return result # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(orchestrator) - state = task_hub_client.wait_for_orchestration_start(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.RUNNING + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + state = task_hub_client.wait_for_orchestration_start(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.RUNNING - task_hub_client.terminate_orchestration(id, output="some reason for termination") - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.TERMINATED - assert state.serialized_output == json.dumps("some reason for termination") + task_hub_client.terminate_orchestration(id, output="some reason for termination") + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.TERMINATED + assert state.serialized_output == json.dumps("some reason for termination") def test_terminate_recursive(): thread_lock = threading.Lock() activity_counter = 0 - delay_time = 4 # seconds + delay_time = ( + 2 # seconds (already optimized from 4s - don't reduce further as it can leads to failure) + ) def increment(ctx, _): with thread_lock: @@ -303,36 +338,39 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): yield task.when_all(tasks) for recurse in [True, False]: - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_activity(increment) w.add_orchestrator(orchestrator_child) w.add_orchestrator(parent_orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient() - instance_id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=5) - - time.sleep(2) - - output = "Recursive termination = {recurse}" - task_hub_client.terminate_orchestration(instance_id, output=output, recursive=recurse) - - metadata = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30) - - assert metadata is not None - assert metadata.runtime_status == client.OrchestrationStatus.TERMINATED - assert metadata.serialized_output == f'"{output}"' + with client.TaskHubGrpcClient() as task_hub_client: + instance_id = task_hub_client.schedule_new_orchestration( + parent_orchestrator, input=5 + ) - time.sleep(delay_time) + time.sleep(1) # Brief delay to let orchestrations start - if recurse: - assert activity_counter == 0, ( - "Activity should not have executed with recursive 
termination" + output = "Recursive termination = {recurse}" + task_hub_client.terminate_orchestration( + instance_id, output=output, recursive=recurse ) - else: - assert activity_counter == 5, ( - "Activity should have executed without recursive termination" + + metadata = task_hub_client.wait_for_orchestration_completion( + instance_id, timeout=30 ) + assert metadata is not None + assert metadata.runtime_status == client.OrchestrationStatus.TERMINATED + assert metadata.serialized_output == f'"{output}"' + time.sleep(delay_time) # Wait for timer to check activity execution + if recurse: + assert activity_counter == 0, ( + "Activity should not have executed with recursive termination" + ) + else: + assert activity_counter == 5, ( + "Activity should have executed without recursive termination" + ) def test_continue_as_new(): @@ -351,7 +389,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): return all_results # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() @@ -391,7 +429,7 @@ def orchestrator(ctx: task.OrchestrationContext, counter: int): else: return {"counter": counter, "processed": processed, "all_results": activity_results} - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_activity(double_activity) w.add_orchestrator(orchestrator) w.start() @@ -424,13 +462,13 @@ def test_retry_policies(): child_orch_counter = 0 throw_activity_counter = 0 - # Second setup: With retry policies + # Second setup: With retry policies (minimal delays for faster tests) retry_policy = task.RetryPolicy( - first_retry_interval=timedelta(seconds=1), + first_retry_interval=timedelta(seconds=0.05), # 0.1 → 0.05 (50% faster) max_number_of_attempts=3, backoff_coefficient=1, - max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=30), + max_retry_interval=timedelta(seconds=0.5), # 1 → 0.5 + retry_timeout=timedelta(seconds=2), # 3 → 2 ) def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _): @@ -449,7 +487,7 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _): throw_activity_counter += 1 raise RuntimeError("Kah-BOOOOM!!!") - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(parent_orchestrator_with_retry) w.add_orchestrator(child_orchestrator_with_retry) w.add_activity(throw_activity_with_retry) @@ -468,19 +506,46 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _): assert throw_activity_counter == 9 assert child_orch_counter == 3 + # Test 2: Verify NonRetryableError prevents retries even with retry policy + non_retryable_counter = 0 + + def throw_non_retryable(ctx: task.ActivityContext, _): + nonlocal non_retryable_counter + non_retryable_counter += 1 + raise task.NonRetryableError("Cannot retry this!") + + def orchestrator_with_non_retryable(ctx: task.OrchestrationContext, _): + # Even with retry policy, NonRetryableError should fail immediately + yield ctx.call_activity(throw_non_retryable, retry_policy=retry_policy) + + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: + w.add_orchestrator(orchestrator_with_non_retryable) + w.add_activity(throw_non_retryable) + w.start() + + task_hub_client = client.TaskHubGrpcClient() + id = task_hub_client.schedule_new_orchestration(orchestrator_with_non_retryable) + state = task_hub_client.wait_for_orchestration_completion(id, 
timeout=30)
+        assert state is not None
+        assert state.runtime_status == client.OrchestrationStatus.FAILED
+        assert state.failure_details is not None
+        assert "Cannot retry this!" in state.failure_details.message
+        # Key assertion: activity was called exactly once (no retries)
+        assert non_retryable_counter == 1
+
 
 
 def test_retry_timeout():
     # This test verifies that the retry timeout is working as expected.
-    # Max number of attempts is 5 and retry timeout is 14 seconds.
-    # Total seconds consumed till 4th attempt is 1 + 2 + 4 + 8 = 15 seconds.
-    # So, the 5th attempt should not be made and the orchestration should fail.
+    # Max number of attempts is 5 and retry timeout is 13 seconds.
+    # The cumulative delay before the 5th attempt would be 1 + 2 + 4 + 8 = 15 seconds, which exceeds the timeout.
+    # So, the 5th attempt should not be made and the orchestration should fail.
     throw_activity_counter = 0
     retry_policy = task.RetryPolicy(
         first_retry_interval=timedelta(seconds=1),
         max_number_of_attempts=5,
         backoff_coefficient=2,
         max_retry_interval=timedelta(seconds=10),
-        retry_timeout=timedelta(seconds=14),
+        retry_timeout=timedelta(seconds=13),  # Set just below the 15s cumulative delay of the 5th attempt
     )
 
     def mock_orchestrator(ctx: task.OrchestrationContext, _):
@@ -491,7 +556,7 @@ def throw_activity(ctx: task.ActivityContext, _):
         throw_activity_counter += 1
         raise RuntimeError("Kah-BOOOOM!!!")
 
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_orchestrator(mock_orchestrator)
         w.add_activity(throw_activity)
         w.start()
@@ -513,7 +578,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
         ctx.set_custom_status("foobaz")
 
     # Start a worker, which will connect to the sidecar in a background thread
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_orchestrator(empty_orchestrator)
         w.start()
 
diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py
index c441bdc..b71e70b 100644
--- a/tests/durabletask/test_orchestration_e2e_async.py
+++ b/tests/durabletask/test_orchestration_e2e_async.py
@@ -110,7 +110,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int):
         return error_msg
 
     # Start a worker, which will connect to the sidecar in a background thread
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_orchestrator(orchestrator)
         w.add_activity(throw)
         w.add_activity(increment_counter)
@@ -153,7 +153,7 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
         yield task.when_all(tasks)
 
     # Start a worker, which will connect to the sidecar in a background thread
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_activity(increment)
         w.add_orchestrator(orchestrator_child)
         w.add_orchestrator(parent_orchestrator)
@@ -178,7 +178,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
         return [a, b, c]
 
     # Start a worker, which will connect to the sidecar in a background thread
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_orchestrator(orchestrator)
         w.start()
 
@@ -208,7 +208,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
         return "timed out"
 
     # Start a worker, which will connect to the sidecar in a background thread
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_orchestrator(orchestrator)
         w.start()
 
@@ -234,7 +234,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
         return result
 
     # Start a worker, which will connect to the sidecar in a background thread
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_orchestrator(orchestrator)
         w.start()
         # there could be a race condition if the workflow is scheduled before orchestrator is started
@@ -275,7 +275,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
         return result
 
     # Start a worker, which will connect to the sidecar in a background thread
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_orchestrator(orchestrator)
         w.start()
 
@@ -302,7 +302,7 @@ def child(ctx: task.OrchestrationContext, _):
         return result
 
     # Start a worker, which will connect to the sidecar in a background thread
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_orchestrator(root)
         w.add_orchestrator(child)
         w.start()
@@ -345,7 +345,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int):
         return all_results
 
     # Start a worker, which will connect to the sidecar in a background thread
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_orchestrator(orchestrator)
         w.start()
 
@@ -376,13 +376,13 @@ async def test_retry_policies():
     child_orch_counter = 0
     throw_activity_counter = 0
 
-    # Second setup: With retry policies
+    # Second setup: With retry policies (minimal delays for faster tests)
     retry_policy = task.RetryPolicy(
-        first_retry_interval=timedelta(seconds=1),
+        first_retry_interval=timedelta(seconds=0.05),  # 1 → 0.05 to keep the test fast
         max_number_of_attempts=3,
         backoff_coefficient=1,
-        max_retry_interval=timedelta(seconds=10),
-        retry_timeout=timedelta(seconds=30),
+        max_retry_interval=timedelta(seconds=0.5),  # 10 → 0.5
+        retry_timeout=timedelta(seconds=2),  # 30 → 2
     )
 
     def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _):
@@ -401,7 +401,7 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _):
         throw_activity_counter += 1
         raise RuntimeError("Kah-BOOOOM!!!")
 
-    with worker.TaskHubGrpcWorker() as w:
+    with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
         w.add_orchestrator(parent_orchestrator_with_retry)
         w.add_orchestrator(child_orchestrator_with_retry)
         w.add_activity(throw_activity_with_retry)
@@ -423,16 +423,16 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _):
 
 async def test_retry_timeout():
     # This test verifies that the retry timeout is working as expected.
-    # Max number of attempts is 5 and retry timeout is 14 seconds.
-    # Total seconds consumed till 4th attempt is 1 + 2 + 4 + 8 = 15 seconds.
-    # So, the 5th attempt should not be made and the orchestration should fail.
+    # Max number of attempts is 5 and retry timeout is 13 seconds.
+    # The cumulative delay before the 5th attempt would be 1 + 2 + 4 + 8 = 15 seconds, which exceeds the timeout.
+    # So, the 5th attempt should not be made and the orchestration should fail.
throw_activity_counter = 0 retry_policy = task.RetryPolicy( first_retry_interval=timedelta(seconds=1), max_number_of_attempts=5, backoff_coefficient=2, max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=14), + retry_timeout=timedelta(seconds=13), # Set just before 4th attempt ) def mock_orchestrator(ctx: task.OrchestrationContext, _): @@ -443,7 +443,7 @@ def throw_activity(ctx: task.ActivityContext, _): throw_activity_counter += 1 raise RuntimeError("Kah-BOOOOM!!!") - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(mock_orchestrator) w.add_activity(throw_activity) w.start() @@ -465,7 +465,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): ctx.set_custom_status("foobaz") # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(empty_orchestrator) w.start() diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index 964512f..bf81f26 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -3,7 +3,7 @@ import json import logging -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest @@ -826,7 +826,7 @@ def test_nondeterminism_expected_sub_orchestration_task_completion_wrong_task_ty def orchestrator(ctx: task.OrchestrationContext, _): result = yield ctx.create_timer( - datetime.utcnow() + datetime.now(timezone.utc) ) # created timer but history expects sub-orchestration return result @@ -920,7 +920,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Complete the timer task. The orchestration should move to the wait_for_external_event step, which # should then complete immediately because the event was buffered in the old event history. - timer_due_time = datetime.utcnow() + timedelta(days=1) + timer_due_time = datetime.now(timezone.utc) + timedelta(days=1) old_events = new_events + [helpers.new_timer_created_event(1, timer_due_time)] new_events = [helpers.new_timer_fired_event(1, timer_due_time)] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) @@ -1013,9 +1013,9 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): helpers.new_event_raised_event("my_event", encoded_input="42"), helpers.new_event_raised_event("my_event", encoded_input="43"), helpers.new_event_raised_event("my_event", encoded_input="44"), - helpers.new_timer_created_event(1, datetime.utcnow() + timedelta(days=1)), + helpers.new_timer_created_event(1, datetime.now(timezone.utc) + timedelta(days=1)), ] - new_events = [helpers.new_timer_fired_event(1, datetime.utcnow() + timedelta(days=1))] + new_events = [helpers.new_timer_fired_event(1, datetime.now(timezone.utc) + timedelta(days=1))] # Execute the orchestration. 
It should be in a running state waiting for the timer to fire executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) @@ -1447,6 +1447,261 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert str(ex) in complete_action.failureDetails.errorMessage +def test_activity_non_retryable_default_exception(): + """If activity fails with NonRetryableError, it should not be retried and orchestration should fail immediately.""" + + def dummy_activity(ctx, _): + raise task.NonRetryableError("boom") + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity( + dummy_activity, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + backoff_coefficient=1, + ), + ) + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + current_timestamp = datetime.utcnow() + old_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] + new_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_task_failed_event(1, task.NonRetryableError("boom")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED + assert complete_action.failureDetails.errorMessage.__contains__("Activity task #1 failed: boom") + + +def test_activity_non_retryable_policy_name(): + """If policy marks ValueError as non-retryable (by name), fail immediately without retry.""" + + def dummy_activity(ctx, _): + raise ValueError("boom") + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity( + dummy_activity, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=5, + non_retryable_error_types=["ValueError"], + ), + ) + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + current_timestamp = datetime.utcnow() + old_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] + new_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_task_failed_event(1, ValueError("boom")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED + assert complete_action.failureDetails.errorMessage.__contains__("Activity task #1 failed: boom") + + +def test_activity_generic_exception_is_retryable(): + """Verify that generic Exception is retryable by default (not treated as non-retryable).""" + + def dummy_activity(ctx, _): + raise Exception("generic error") + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity( + dummy_activity, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + 
backoff_coefficient=1, + ), + ) + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + current_timestamp = datetime.utcnow() + # First attempt fails + old_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] + new_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_task_failed_event(1, Exception("generic error")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + # Should schedule a retry timer, not fail immediately + assert len(actions) == 1 + assert actions[0].HasField("createTimer") + assert actions[0].id == 2 + + # Simulate the timer firing and activity being rescheduled + expected_fire_at = current_timestamp + timedelta(seconds=1) + old_events = old_events + new_events + current_timestamp = expected_fire_at + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_timer_fired_event(2, current_timestamp), + ] + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + assert len(actions) == 2 # timer + rescheduled task + assert actions[1].HasField("scheduleTask") + assert actions[1].id == 1 + + # Second attempt also fails + old_events = old_events + new_events + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_task_failed_event(1, Exception("generic error")), + ] + + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + # Should schedule another retry timer + assert len(actions) == 3 + assert actions[2].HasField("createTimer") + assert actions[2].id == 3 + + # Simulate the timer firing and activity being rescheduled + expected_fire_at = current_timestamp + timedelta(seconds=1) + old_events = old_events + new_events + current_timestamp = expected_fire_at + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_timer_fired_event(3, current_timestamp), + ] + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + assert len(actions) == 3 # timer + rescheduled task + assert actions[1].HasField("scheduleTask") + assert actions[1].id == 1 + + # Third attempt fails - should exhaust retries + old_events = old_events + new_events + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_task_failed_event(1, Exception("generic error")), + ] + + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + # Now should fail - no more retries + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED + assert complete_action.failureDetails.errorMessage.__contains__( + "Activity task #1 failed: generic error" + ) + + +def test_sub_orchestration_non_retryable_default_exception(): + """If sub-orchestrator fails with NonRetryableError, do not retry and fail immediately.""" + + def child(ctx: task.OrchestrationContext, _): + pass + + def parent(ctx: task.OrchestrationContext, _): + yield ctx.call_sub_orchestrator( + child, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + 
), + ) + + registry = worker._Registry() + child_name = registry.add_orchestrator(child) + parent_name = registry.add_orchestrator(parent) + + current_timestamp = datetime.utcnow() + old_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_execution_started_event(parent_name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_sub_orchestration_created_event(1, child_name, "sub-1", encoded_input=None), + ] + new_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_sub_orchestration_failed_event(1, task.NonRetryableError("boom")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED + assert complete_action.failureDetails.errorMessage.__contains__( + "Sub-orchestration task #1 failed: boom" + ) + + +def test_sub_orchestration_non_retryable_policy_type(): + """If policy marks ValueError as non-retryable (by class), fail immediately without retry.""" + + def child(ctx: task.OrchestrationContext, _): + pass + + def parent(ctx: task.OrchestrationContext, _): + yield ctx.call_sub_orchestrator( + child, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=5, + non_retryable_error_types=[ValueError], + ), + ) + + registry = worker._Registry() + child_name = registry.add_orchestrator(child) + parent_name = registry.add_orchestrator(parent) + + current_timestamp = datetime.utcnow() + old_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_execution_started_event(parent_name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_sub_orchestration_created_event(1, child_name, "sub-1", encoded_input=None), + ] + new_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_sub_orchestration_failed_event(1, ValueError("boom")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED + assert complete_action.failureDetails.errorMessage.__contains__( + "Sub-orchestration task #1 failed: boom" + ) + + def get_and_validate_single_complete_orchestration_action( actions: list[pb.OrchestratorAction], ) -> pb.CompleteOrchestrationAction: diff --git a/tox.ini b/tox.ini index 9b21313..b6bc7ba 100644 --- a/tox.ini +++ b/tox.ini @@ -10,11 +10,9 @@ runner = virtualenv [testenv] # you can run tox with the e2e pytest marker using tox factors: -# tox -e py310,py311,py312,py313,py314 -- e2e -# or single one with: # tox -e py310-e2e -# to use custom grpc endpoint and not capture print statements (-s arg in pytest): -# DAPR_GRPC_ENDPOINT=localhost:12345 tox -e py310-e2e -- -s +# to use custom grpc endpoint: +# DAPR_GRPC_ENDPOINT=localhost:12345 tox -e py310-e2e setenv = PYTHONDONTWRITEBYTECODE=1 deps = .[dev]