Skip to content

Commit 776253a

Browse files
committed
cleanup/feedback
Signed-off-by: Filinto Duran <[email protected]>
1 parent cb7c6e9 commit 776253a

21 files changed

+503
-421
lines changed

durabletask/aio/__init__.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from .compatibility import OrchestrationContextProtocol, ensure_compatibility
2828

2929
# Core context and driver
30-
from .context import AsyncWorkflowContext, WorkflowInfo
30+
from .context import AsyncWorkflowContext
3131
from .driver import CoroutineOrchestratorRunner, WorkflowFunction
3232

3333
# Sandbox and error handling
@@ -38,20 +38,12 @@
3838
WorkflowTimeoutError,
3939
WorkflowValidationError,
4040
)
41-
from .sandbox import (
42-
SandboxMode,
43-
_NonDeterminismDetector,
44-
sandbox_best_effort,
45-
sandbox_off,
46-
sandbox_scope,
47-
sandbox_strict,
48-
)
41+
from .sandbox import SandboxMode, _NonDeterminismDetector
4942

5043
__all__ = [
5144
"AsyncTaskHubGrpcClient",
5245
# Core classes
5346
"AsyncWorkflowContext",
54-
"WorkflowInfo",
5547
"CoroutineOrchestratorRunner",
5648
"WorkflowFunction",
5749
# Deterministic utilities
@@ -73,11 +65,7 @@
7365
"SwallowExceptionAwaitable",
7466
"gather",
7567
# Sandbox and utilities
76-
"sandbox_scope",
7768
"SandboxMode",
78-
"sandbox_off",
79-
"sandbox_best_effort",
80-
"sandbox_strict",
8169
"_NonDeterminismDetector",
8270
# Compatibility protocol
8371
"OrchestrationContextProtocol",

durabletask/aio/awaitables.py

Lines changed: 55 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def __await__(self) -> Generator[Any, Any, List[TOutput]]:
378378
class WhenAnyAwaitable(AwaitableBase[task.Task[Any]]):
379379
"""Awaitable for when_any operations (wait for any task to complete)."""
380380

381-
__slots__ = ("_tasks_like",)
381+
__slots__ = ("_originals", "_underlying")
382382

383383
def __init__(self, tasks_like: Iterable[Union[AwaitableBase[Any], task.Task[Any]]]):
384384
"""
@@ -388,33 +388,33 @@ def __init__(self, tasks_like: Iterable[Union[AwaitableBase[Any], task.Task[Any]
388388
tasks_like: Iterable of awaitables or tasks to wait for
389389
"""
390390
super().__init__()
391-
self._tasks_like = list(tasks_like)
391+
self._originals = list(tasks_like)
392+
# Defer conversion to avoid issues with incomplete mocks and coroutine reuse
393+
self._underlying: Optional[List[task.Task[Any]]] = None
394+
395+
def _ensure_underlying(self) -> List[task.Task[Any]]:
396+
"""Lazily convert originals to tasks, caching the result."""
397+
if self._underlying is None:
398+
self._underlying = []
399+
for a in self._originals:
400+
if isinstance(a, AwaitableBase):
401+
self._underlying.append(a._to_task())
402+
elif isinstance(a, task.Task):
403+
self._underlying.append(a)
404+
else:
405+
raise TypeError("when_any expects AwaitableBase or durabletask.task.Task")
406+
return self._underlying
392407

393408
def _to_task(self) -> task.Task[Any]:
394409
"""Convert to a when_any task."""
395-
underlying: List[task.Task[Any]] = []
396-
for a in self._tasks_like:
397-
if isinstance(a, AwaitableBase):
398-
underlying.append(a._to_task())
399-
elif isinstance(a, task.Task):
400-
underlying.append(a)
401-
else:
402-
raise TypeError("when_any expects AwaitableBase or durabletask.task.Task")
403-
return cast(task.Task[Any], task.when_any(underlying))
410+
return cast(task.Task[Any], task.when_any(self._ensure_underlying()))
404411

405412
def __await__(self) -> Generator[Any, Any, Any]:
406413
"""Return a proxy that compares equal to the original item and exposes get_result()."""
407-
when_any_task = self._to_task()
414+
underlying = self._ensure_underlying()
415+
when_any_task = task.when_any(underlying)
408416
completed = yield when_any_task
409417

410-
# Build underlying mapping original -> underlying task
411-
underlying: List[task.Task[Any]] = []
412-
for a in self._tasks_like:
413-
if isinstance(a, AwaitableBase):
414-
underlying.append(a._to_task())
415-
elif isinstance(a, task.Task):
416-
underlying.append(a)
417-
418418
class _CompletedProxy:
419419
__slots__ = ("_original", "_completed")
420420

@@ -431,20 +431,28 @@ def get_result(self) -> Any:
431431
return self._completed.get_result()
432432
return getattr(self._completed, "result", None)
433433

434+
@property
435+
def __dict__(self) -> dict[str, Any]:
436+
"""Expose a dict-like view for compatibility with user code."""
437+
return {
438+
"_original": self._original,
439+
"_completed": self._completed,
440+
}
441+
434442
def __repr__(self) -> str: # pragma: no cover
435443
return f"<WhenAnyCompleted proxy for {self._original!r}>"
436444

437445
# If the runtime returned a non-task sentinel (e.g., tests), assume first item won
438446
if not isinstance(completed, task.Task):
439-
return _CompletedProxy(self._tasks_like[0], completed)
447+
return _CompletedProxy(self._originals[0], completed)
440448

441449
# Map completed task back to the original item and return proxy
442-
for original, under in zip(self._tasks_like, underlying, strict=False):
450+
for original, under in zip(self._originals, underlying, strict=False):
443451
if completed == under:
444452
return _CompletedProxy(original, completed)
445453

446454
# Fallback proxy; treat the first as original
447-
return _CompletedProxy(self._tasks_like[0], completed)
455+
return _CompletedProxy(self._originals[0], completed)
448456

449457

450458
class WhenAnyResultAwaitable(AwaitableBase[tuple[int, Any]]):
@@ -454,7 +462,7 @@ class WhenAnyResultAwaitable(AwaitableBase[tuple[int, Any]]):
454462
This is useful when you need to know which task completed first, not just its result.
455463
"""
456464

457-
__slots__ = ("_tasks_like", "_awaitables")
465+
__slots__ = ("_originals", "_underlying")
458466

459467
def __init__(self, tasks_like: Iterable[Union[AwaitableBase[Any], task.Task[Any]]]):
460468
"""
@@ -464,41 +472,37 @@ def __init__(self, tasks_like: Iterable[Union[AwaitableBase[Any], task.Task[Any]
464472
tasks_like: Iterable of awaitables or tasks to wait for
465473
"""
466474
super().__init__()
467-
self._tasks_like = list(tasks_like)
468-
self._awaitables = self._tasks_like # Alias for compatibility
475+
self._originals = list(tasks_like)
476+
# Defer conversion to avoid issues with incomplete mocks and coroutine reuse
477+
self._underlying: Optional[List[task.Task[Any]]] = None
478+
479+
def _ensure_underlying(self) -> List[task.Task[Any]]:
480+
"""Lazily convert originals to tasks, caching the result."""
481+
if self._underlying is None:
482+
self._underlying = []
483+
for a in self._originals:
484+
if isinstance(a, AwaitableBase):
485+
self._underlying.append(a._to_task())
486+
elif isinstance(a, task.Task):
487+
self._underlying.append(a)
488+
else:
489+
raise TypeError(
490+
"when_any_with_result expects AwaitableBase or durabletask.task.Task"
491+
)
492+
return self._underlying
469493

470494
def _to_task(self) -> task.Task[Any]:
471495
"""Convert to a when_any task with result tracking."""
472-
underlying: List[task.Task[Any]] = []
473-
for a in self._tasks_like:
474-
if isinstance(a, AwaitableBase):
475-
underlying.append(a._to_task())
476-
elif isinstance(a, task.Task):
477-
underlying.append(a)
478-
else:
479-
raise TypeError(
480-
"when_any_with_result expects AwaitableBase or durabletask.task.Task"
481-
)
482-
483-
# Use when_any and then determine which task completed
484-
when_any_task = task.when_any(underlying)
485-
return cast(task.Task[Any], when_any_task)
496+
return cast(task.Task[Any], task.when_any(self._ensure_underlying()))
486497

487498
def __await__(self) -> Generator[Any, Any, tuple[int, Any]]:
488499
"""Override to provide index + result tuple."""
489-
t = self._to_task()
490-
completed_task = yield t
491-
492-
# Find which task completed by comparing results
493-
underlying_tasks: List[task.Task[Any]] = []
494-
for a in self._tasks_like:
495-
if isinstance(a, AwaitableBase):
496-
underlying_tasks.append(a._to_task())
497-
elif isinstance(a, task.Task):
498-
underlying_tasks.append(a)
500+
underlying = self._ensure_underlying()
501+
when_any_task = task.when_any(underlying)
502+
completed_task = yield when_any_task
499503

500504
# The completed_task should match one of our underlying tasks
501-
for i, underlying_task in enumerate(underlying_tasks):
505+
for i, underlying_task in enumerate(underlying):
502506
if underlying_task == completed_task:
503507
return (i, completed_task.result if hasattr(completed_task, "result") else None)
504508

durabletask/aio/compatibility.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,6 @@ def workflow_name(self) -> Optional[str]:
5454
"""Get the orchestrator name/type for this instance."""
5555
...
5656

57-
@property
58-
def parent_instance_id(self) -> Optional[str]:
59-
"""Get the parent orchestration ID if this is a sub-orchestration."""
60-
...
61-
62-
@property
63-
def history_event_sequence(self) -> Optional[int]:
64-
"""Get the current processed history event sequence."""
65-
...
66-
6757
@property
6858
def is_suspended(self) -> bool:
6959
"""Get whether this orchestration is currently suspended."""
@@ -132,8 +122,6 @@ def ensure_compatibility(context_class: type) -> type:
132122
"current_utc_datetime",
133123
"is_replaying",
134124
"workflow_name",
135-
"parent_instance_id",
136-
"history_event_sequence",
137125
"is_suspended",
138126
]
139127

durabletask/aio/context.py

Lines changed: 2 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from __future__ import annotations
2121

2222
import os
23-
from dataclasses import dataclass
2423
from datetime import datetime, timedelta
2524
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, cast
2625

@@ -45,24 +44,6 @@
4544
T = TypeVar("T")
4645

4746

48-
@dataclass(frozen=True)
49-
class WorkflowInfo:
50-
"""
51-
Read-only metadata snapshot about the running workflow execution.
52-
53-
Similar to Temporal's workflow.info, this provides convenient access to
54-
workflow execution metadata in a single immutable object.
55-
"""
56-
57-
instance_id: str
58-
workflow_name: Optional[str]
59-
is_replaying: bool
60-
is_suspended: bool
61-
parent_instance_id: Optional[str]
62-
current_time: datetime
63-
history_event_sequence: int
64-
65-
6647
@ensure_compatibility
6748
class AsyncWorkflowContext(DeterministicContextMixin):
6849
"""
@@ -130,52 +111,15 @@ def is_replaying(self) -> bool:
130111
@property
131112
def is_suspended(self) -> bool:
132113
"""Check if the workflow is currently suspended."""
133-
return getattr(self._base_ctx, "is_suspended", False)
114+
return self._base_ctx.is_suspended
134115

135116
@property
136117
def workflow_name(self) -> Optional[str]:
137118
"""Get the workflow name."""
138119
return getattr(self._base_ctx, "workflow_name", None)
139120

140-
@property
141-
def parent_instance_id(self) -> Optional[str]:
142-
"""Get the parent instance ID (for sub-orchestrators)."""
143-
return getattr(self._base_ctx, "parent_instance_id", None)
144-
145-
@property
146-
def history_event_sequence(self) -> int:
147-
"""Get the current history event sequence number."""
148-
return getattr(self._base_ctx, "history_event_sequence", 0)
149-
150-
@property
151-
def execution_info(self) -> Optional[Any]:
152-
"""Get execution_info from the base context if available, else None."""
153-
return getattr(self._base_ctx, "execution_info", None)
154-
155-
@property
156-
def info(self) -> WorkflowInfo:
157-
"""
158-
Get a read-only snapshot of workflow execution metadata.
159-
160-
This provides a Temporal-style info object bundling instance_id, workflow_name,
161-
is_replaying, timestamps, and other metadata in a single immutable object.
162-
Useful for deterministic logging, idempotency keys, and conditional logic based on replay state.
163-
164-
Returns:
165-
WorkflowInfo: Immutable dataclass with workflow execution metadata
166-
"""
167-
return WorkflowInfo(
168-
instance_id=self.instance_id,
169-
workflow_name=self.workflow_name,
170-
is_replaying=self.is_replaying,
171-
is_suspended=self.is_suspended,
172-
parent_instance_id=self.parent_instance_id,
173-
current_time=self.current_utc_datetime,
174-
history_event_sequence=self.history_event_sequence,
175-
)
176-
177121
# Activity operations
178-
def activity(
122+
def call_activity(
179123
self,
180124
activity_fn: Union[dt_task.Activity[Any, Any], str],
181125
*,
@@ -206,24 +150,6 @@ def activity(
206150
metadata=metadata,
207151
)
208152

209-
def call_activity(
210-
self,
211-
activity_fn: Union[dt_task.Activity[Any, Any], str],
212-
*,
213-
input: Any = None,
214-
retry_policy: Any = None,
215-
app_id: Optional[str] = None,
216-
metadata: Optional[Dict[str, str]] = None,
217-
) -> ActivityAwaitable[Any]:
218-
"""Alias for activity() method for API compatibility."""
219-
return self.activity(
220-
activity_fn,
221-
input=input,
222-
retry_policy=retry_policy,
223-
app_id=app_id,
224-
metadata=metadata,
225-
)
226-
227153
# Sub-orchestrator operations
228154
def sub_orchestrator(
229155
self,

0 commit comments

Comments
 (0)