Skip to content

Commit 366b2f6

Browse files
committed
Add non-retryable errors to activities, helpers for shutdown, and deterministic functions similar to the .NET SDK
Signed-off-by: Filinto Duran <[email protected]>
1 parent 7f89f6a commit 366b2f6

File tree

16 files changed

+1593
-132
lines changed

16 files changed

+1593
-132
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ Certain aspects like multi-app activities require the full dapr runtime to be ru
194194
```shell
195195
dapr init || true
196196

197-
dapr run --app-id test-app --dapr-grpc-port 4001 --components-path ./examples/components/
197+
dapr run --app-id test-app --dapr-grpc-port 4001 --resources-path ./examples/components/
198198
```
199199

200200
To run the E2E tests on a specific python version (eg: 3.11), run the following command from the project root:

dev-requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python # supports protobuf 6.x and aligns with generated code

durabletask/client.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,28 @@ def __init__(
127127
interceptors=interceptors,
128128
options=channel_options,
129129
)
130+
self._channel = channel
130131
self._stub = stubs.TaskHubSidecarServiceStub(channel)
131132
self._logger = shared.get_logger("client", log_handler, log_formatter)
132133

134+
def __enter__(self):
135+
return self
136+
137+
def __exit__(self, exc_type, exc, tb):
138+
try:
139+
self.close()
140+
finally:
141+
return False
142+
143+
def close(self) -> None:
144+
"""Close the underlying gRPC channel."""
145+
try:
146+
# grpc.Channel.close() is idempotent
147+
self._channel.close()
148+
except Exception:
149+
# Best-effort cleanup
150+
pass
151+
133152
def schedule_new_orchestration(
134153
self,
135154
orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
@@ -188,10 +207,59 @@ def wait_for_orchestration_completion(
188207
) -> Optional[OrchestrationState]:
189208
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
190209
try:
191-
grpc_timeout = None if timeout == 0 else timeout
192-
self._logger.info(
193-
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete."
194-
)
210+
# gRPC timeout mapping (pytest unit tests may pass None explicitly)
211+
grpc_timeout = None if (timeout is None or timeout == 0) else timeout
212+
213+
# If timeout is None or 0, skip pre-checks/polling and call server-side wait directly
214+
if timeout is None or timeout == 0:
215+
self._logger.info(
216+
f"Waiting {'indefinitely' if not timeout else f'up to {timeout}s'} for instance '{instance_id}' to complete."
217+
)
218+
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(
219+
req, timeout=grpc_timeout
220+
)
221+
state = new_orchestration_state(req.instanceId, res)
222+
return state
223+
224+
# For positive timeout, best-effort pre-check and short polling to avoid long server waits
225+
try:
226+
# First check if the orchestration is already completed
227+
current_state = self.get_orchestration_state(
228+
instance_id, fetch_payloads=fetch_payloads
229+
)
230+
if current_state and current_state.runtime_status in [
231+
OrchestrationStatus.COMPLETED,
232+
OrchestrationStatus.FAILED,
233+
OrchestrationStatus.TERMINATED,
234+
]:
235+
return current_state
236+
237+
# Poll for completion with exponential backoff to handle eventual consistency
238+
import time
239+
240+
poll_timeout = min(timeout, 10)
241+
poll_start = time.time()
242+
poll_interval = 0.1
243+
244+
while time.time() - poll_start < poll_timeout:
245+
current_state = self.get_orchestration_state(
246+
instance_id, fetch_payloads=fetch_payloads
247+
)
248+
249+
if current_state and current_state.runtime_status in [
250+
OrchestrationStatus.COMPLETED,
251+
OrchestrationStatus.FAILED,
252+
OrchestrationStatus.TERMINATED,
253+
]:
254+
return current_state
255+
256+
time.sleep(poll_interval)
257+
poll_interval = min(poll_interval * 1.5, 1.0) # Exponential backoff, max 1s
258+
except Exception:
259+
# Ignore pre-check/poll issues (e.g., mocked stubs in unit tests) and fall back
260+
pass
261+
262+
self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to complete.")
195263
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(
196264
req, timeout=grpc_timeout
197265
)

durabletask/deterministic.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
"""
2+
Copyright 2025 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
"""
15+
Deterministic utilities for Durable Task workflows (async and generator).
16+
17+
This module provides deterministic alternatives to non-deterministic Python
18+
functions, ensuring workflow replay consistency across different executions.
19+
It is shared by both the asyncio authoring model and the generator-based model.
20+
"""
21+
22+
import hashlib
23+
import random
24+
import string as _string
25+
import uuid
26+
from collections.abc import Sequence
27+
from dataclasses import dataclass
28+
from datetime import datetime
29+
from typing import Optional, Protocol, TypeVar, runtime_checkable
30+
31+
32+
@dataclass
33+
class DeterminismSeed:
34+
"""Seed data for deterministic operations."""
35+
36+
instance_id: str
37+
orchestration_unix_ts: int
38+
39+
def to_int(self) -> int:
40+
"""Convert seed to integer for PRNG initialization."""
41+
combined = f"{self.instance_id}:{self.orchestration_unix_ts}"
42+
hash_bytes = hashlib.sha256(combined.encode("utf-8")).digest()
43+
return int.from_bytes(hash_bytes[:8], byteorder="big")
44+
45+
46+
def derive_seed(instance_id: str, orchestration_time: datetime) -> int:
47+
"""
48+
Derive a deterministic seed from instance ID and orchestration time.
49+
"""
50+
ts = int(orchestration_time.timestamp())
51+
return DeterminismSeed(instance_id=instance_id, orchestration_unix_ts=ts).to_int()
52+
53+
54+
def deterministic_random(instance_id: str, orchestration_time: datetime) -> random.Random:
55+
"""
56+
Create a deterministic random number generator.
57+
"""
58+
seed = derive_seed(instance_id, orchestration_time)
59+
return random.Random(seed)
60+
61+
62+
def deterministic_uuid4(rnd: random.Random) -> uuid.UUID:
63+
"""
64+
Generate a deterministic UUID4 using the provided random generator.
65+
66+
Note: This is deprecated in favor of deterministic_uuid_v5 which matches
67+
the .NET implementation for cross-language compatibility.
68+
"""
69+
bytes_ = bytes(rnd.randrange(0, 256) for _ in range(16))
70+
bytes_list = list(bytes_)
71+
bytes_list[6] = (bytes_list[6] & 0x0F) | 0x40 # Version 4
72+
bytes_list[8] = (bytes_list[8] & 0x3F) | 0x80 # Variant bits
73+
return uuid.UUID(bytes=bytes(bytes_list))
74+
75+
76+
def deterministic_uuid_v5(instance_id: str, current_datetime: datetime, counter: int) -> uuid.UUID:
77+
"""
78+
Generate a deterministic UUID v5 matching the .NET implementation.
79+
80+
This implementation matches the durabletask-dotnet NewGuid() method:
81+
https://github.com/microsoft/durabletask-dotnet/blob/main/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs
82+
83+
Args:
84+
instance_id: The orchestration instance ID.
85+
current_datetime: The current orchestration datetime (frozen during replay).
86+
counter: The per-call counter (starts at 0 on each replay).
87+
88+
Returns:
89+
A deterministic UUID v5 that will be the same across replays.
90+
"""
91+
# DNS namespace UUID - same as .NET DnsNamespaceValue
92+
namespace = uuid.UUID("9e952958-5e33-4daf-827f-2fa12937b875")
93+
94+
# Build name matching .NET format: instanceId_datetime_counter
95+
# Using isoformat() which produces ISO 8601 format similar to .NET's ToString("o")
96+
name = f"{instance_id}_{current_datetime.isoformat()}_{counter}"
97+
98+
# Generate UUID v5 (SHA-1 based, matching .NET)
99+
return uuid.uuid5(namespace, name)
100+
101+
102+
@runtime_checkable
103+
class DeterministicContextProtocol(Protocol):
104+
"""Protocol for contexts that provide deterministic operations."""
105+
106+
@property
107+
def instance_id(self) -> str: ...
108+
109+
@property
110+
def current_utc_datetime(self) -> datetime: ...
111+
112+
113+
class DeterministicContextMixin:
114+
"""
115+
Mixin providing deterministic helpers for workflow contexts.
116+
117+
Assumes the inheriting class exposes `instance_id` and `current_utc_datetime` attributes.
118+
119+
This implementation matches the .NET durabletask SDK approach with an explicit
120+
counter for UUID generation that resets on each replay.
121+
"""
122+
123+
def __init__(self, *args, **kwargs):
124+
"""Initialize the mixin with a UUID counter."""
125+
super().__init__(*args, **kwargs)
126+
# Counter for deterministic UUID generation (matches .NET newGuidCounter)
127+
# This counter resets to 0 on each replay, ensuring determinism
128+
self._uuid_counter: int = 0
129+
130+
def now(self) -> datetime:
131+
"""Return orchestration time (deterministic UTC)."""
132+
value = self.current_utc_datetime # type: ignore[attr-defined]
133+
assert isinstance(value, datetime)
134+
return value
135+
136+
def random(self) -> random.Random:
137+
"""Return a PRNG seeded deterministically from instance id and orchestration time."""
138+
rnd = deterministic_random(
139+
self.instance_id, # type: ignore[attr-defined]
140+
self.current_utc_datetime, # type: ignore[attr-defined]
141+
)
142+
# Mark as deterministic for sandbox detector whitelisting of bound methods
143+
try:
144+
rnd._dt_deterministic = True
145+
except Exception:
146+
pass
147+
return rnd
148+
149+
def uuid4(self) -> uuid.UUID:
150+
"""
151+
Return a deterministically generated UUID v5 with explicit counter.
152+
https://www.sohamkamani.com/uuid-versions-explained/#v5-non-random-uuids
153+
154+
This matches the .NET implementation's NewGuid() method which uses:
155+
- Instance ID
156+
- Current UTC datetime (frozen during replay)
157+
- Per-call counter (resets to 0 on each replay)
158+
159+
The counter ensures multiple calls produce different UUIDs while maintaining
160+
determinism across replays.
161+
"""
162+
# Lazily initialize counter if not set by __init__ (for compatibility)
163+
if not hasattr(self, "_uuid_counter"):
164+
self._uuid_counter = 0
165+
166+
result = deterministic_uuid_v5(
167+
self.instance_id, # type: ignore[attr-defined]
168+
self.current_utc_datetime, # type: ignore[attr-defined]
169+
self._uuid_counter,
170+
)
171+
self._uuid_counter += 1
172+
return result
173+
174+
def new_guid(self) -> uuid.UUID:
175+
"""Alias for uuid4 for API parity with other SDKs."""
176+
return self.uuid4()
177+
178+
def random_string(self, length: int, *, alphabet: Optional[str] = None) -> str:
179+
"""Return a deterministically generated random string of the given length."""
180+
if length < 0:
181+
raise ValueError("length must be non-negative")
182+
chars = alphabet if alphabet is not None else (_string.ascii_letters + _string.digits)
183+
if not chars:
184+
raise ValueError("alphabet must not be empty")
185+
rnd = self.random()
186+
size = len(chars)
187+
return "".join(chars[rnd.randrange(0, size)] for _ in range(length))
188+
189+
def random_int(self, min_value: int = 0, max_value: int = 2**31 - 1) -> int:
190+
"""Return a deterministic random integer in the specified range."""
191+
if min_value > max_value:
192+
raise ValueError("min_value must be <= max_value")
193+
rnd = self.random()
194+
return rnd.randint(min_value, max_value)
195+
196+
T = TypeVar("T")
197+
198+
def random_choice(self, sequence: Sequence[T]) -> T:
199+
"""Return a deterministic random element from a non-empty sequence."""
200+
if not sequence:
201+
raise IndexError("Cannot choose from empty sequence")
202+
rnd = self.random()
203+
return rnd.choice(sequence)

durabletask/internal/shared.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def get_logger(
102102
# Add a default log handler if none is provided
103103
if log_handler is None:
104104
log_handler = logging.StreamHandler()
105-
log_handler.setLevel(logging.INFO)
105+
log_handler.setLevel(logging.DEBUG)
106106
logger.handlers.append(log_handler)
107107

108108
# Set a default log formatter to our handler if none is provided

durabletask/task.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,16 @@ class OrchestrationStateError(Exception):
233233
pass
234234

235235

236+
class NonRetryableError(Exception):
237+
"""Exception indicating the operation should not be retried.
238+
239+
If an activity or sub-orchestration raises this exception, retry logic will be
240+
bypassed and the failure will be returned immediately to the orchestrator.
241+
"""
242+
243+
pass
244+
245+
236246
class Task(ABC, Generic[T]):
237247
"""Abstract base class for asynchronous tasks in a durable orchestration."""
238248

@@ -397,7 +407,7 @@ def compute_next_delay(self) -> Optional[timedelta]:
397407
next_delay_f = min(
398408
next_delay_f, self._retry_policy.max_retry_interval.total_seconds()
399409
)
400-
return timedelta(seconds=next_delay_f)
410+
return timedelta(seconds=next_delay_f)
401411

402412
return None
403413

@@ -490,6 +500,7 @@ def __init__(
490500
backoff_coefficient: Optional[float] = 1.0,
491501
max_retry_interval: Optional[timedelta] = None,
492502
retry_timeout: Optional[timedelta] = None,
503+
non_retryable_error_types: Optional[list[Union[str, type]]] = None,
493504
):
494505
"""Creates a new RetryPolicy instance.
495506
@@ -505,6 +516,11 @@ def __init__(
505516
The maximum retry interval to use for any retry attempt.
506517
retry_timeout : Optional[timedelta]
507518
The maximum amount of time to spend retrying the operation.
519+
non_retryable_error_types : Optional[list[Union[str, type]]]
520+
A list of exception type names or classes that should not be retried.
521+
If a failure's error type matches any of these, the task fails immediately.
522+
The built-in NonRetryableError is always treated as non-retryable regardless
523+
of this setting.
508524
"""
509525
# validate inputs
510526
if first_retry_interval < timedelta(seconds=0):
@@ -523,6 +539,17 @@ def __init__(
523539
self._backoff_coefficient = backoff_coefficient
524540
self._max_retry_interval = max_retry_interval
525541
self._retry_timeout = retry_timeout
542+
# Normalize non-retryable error type names to a set of strings
543+
names: Optional[set[str]] = None
544+
if non_retryable_error_types:
545+
names = set()
546+
for t in non_retryable_error_types:
547+
if isinstance(t, str):
548+
if t:
549+
names.add(t)
550+
elif isinstance(t, type):
551+
names.add(t.__name__)
552+
self._non_retryable_error_types = names
526553

527554
@property
528555
def first_retry_interval(self) -> timedelta:
@@ -549,6 +576,15 @@ def retry_timeout(self) -> Optional[timedelta]:
549576
"""The maximum amount of time to spend retrying the operation."""
550577
return self._retry_timeout
551578

579+
@property
580+
def non_retryable_error_types(self) -> Optional[set[str]]:
581+
"""Set of error type names that should not be retried.
582+
583+
Comparison is performed against the errorType string provided by the
584+
backend (typically the exception class name).
585+
"""
586+
return self._non_retryable_error_types
587+
552588

553589
def get_name(fn: Callable) -> str:
554590
"""Returns the name of the provided function"""

0 commit comments

Comments (0)