@@ -378,7 +378,7 @@ def __await__(self) -> Generator[Any, Any, List[TOutput]]:
378378class WhenAnyAwaitable (AwaitableBase [task .Task [Any ]]):
379379 """Awaitable for when_any operations (wait for any task to complete)."""
380380
381- __slots__ = ("_tasks_like" , )
381+ __slots__ = ("_originals" , "_underlying" )
382382
383383 def __init__ (self , tasks_like : Iterable [Union [AwaitableBase [Any ], task .Task [Any ]]]):
384384 """
@@ -388,33 +388,33 @@ def __init__(self, tasks_like: Iterable[Union[AwaitableBase[Any], task.Task[Any]
388388 tasks_like: Iterable of awaitables or tasks to wait for
389389 """
390390 super ().__init__ ()
391- self ._tasks_like = list (tasks_like )
391+ self ._originals = list (tasks_like )
392+ # Defer conversion to avoid issues with incomplete mocks and coroutine reuse
393+ self ._underlying : Optional [List [task .Task [Any ]]] = None
394+
395+ def _ensure_underlying (self ) -> List [task .Task [Any ]]:
396+ """Lazily convert originals to tasks, caching the result."""
397+ if self ._underlying is None :
398+ self ._underlying = []
399+ for a in self ._originals :
400+ if isinstance (a , AwaitableBase ):
401+ self ._underlying .append (a ._to_task ())
402+ elif isinstance (a , task .Task ):
403+ self ._underlying .append (a )
404+ else :
405+ raise TypeError ("when_any expects AwaitableBase or durabletask.task.Task" )
406+ return self ._underlying
392407
393408 def _to_task (self ) -> task .Task [Any ]:
394409 """Convert to a when_any task."""
395- underlying : List [task .Task [Any ]] = []
396- for a in self ._tasks_like :
397- if isinstance (a , AwaitableBase ):
398- underlying .append (a ._to_task ())
399- elif isinstance (a , task .Task ):
400- underlying .append (a )
401- else :
402- raise TypeError ("when_any expects AwaitableBase or durabletask.task.Task" )
403- return cast (task .Task [Any ], task .when_any (underlying ))
410+ return cast (task .Task [Any ], task .when_any (self ._ensure_underlying ()))
404411
405412 def __await__ (self ) -> Generator [Any , Any , Any ]:
406413 """Return a proxy that compares equal to the original item and exposes get_result()."""
407- when_any_task = self ._to_task ()
414+ underlying = self ._ensure_underlying ()
415+ when_any_task = task .when_any (underlying )
408416 completed = yield when_any_task
409417
410- # Build underlying mapping original -> underlying task
411- underlying : List [task .Task [Any ]] = []
412- for a in self ._tasks_like :
413- if isinstance (a , AwaitableBase ):
414- underlying .append (a ._to_task ())
415- elif isinstance (a , task .Task ):
416- underlying .append (a )
417-
418418 class _CompletedProxy :
419419 __slots__ = ("_original" , "_completed" )
420420
@@ -431,20 +431,28 @@ def get_result(self) -> Any:
431431 return self ._completed .get_result ()
432432 return getattr (self ._completed , "result" , None )
433433
434+ @property
435+ def __dict__ (self ) -> dict [str , Any ]:
436+ """Expose a dict-like view for compatibility with user code."""
437+ return {
438+ "_original" : self ._original ,
439+ "_completed" : self ._completed ,
440+ }
441+
434442 def __repr__ (self ) -> str : # pragma: no cover
435443 return f"<WhenAnyCompleted proxy for { self ._original !r} >"
436444
437445 # If the runtime returned a non-task sentinel (e.g., tests), assume first item won
438446 if not isinstance (completed , task .Task ):
439- return _CompletedProxy (self ._tasks_like [0 ], completed )
447+ return _CompletedProxy (self ._originals [0 ], completed )
440448
441449 # Map completed task back to the original item and return proxy
442- for original , under in zip (self ._tasks_like , underlying , strict = False ):
450+ for original , under in zip (self ._originals , underlying , strict = False ):
443451 if completed == under :
444452 return _CompletedProxy (original , completed )
445453
446454 # Fallback proxy; treat the first as original
447- return _CompletedProxy (self ._tasks_like [0 ], completed )
455+ return _CompletedProxy (self ._originals [0 ], completed )
448456
449457
450458class WhenAnyResultAwaitable (AwaitableBase [tuple [int , Any ]]):
@@ -454,7 +462,7 @@ class WhenAnyResultAwaitable(AwaitableBase[tuple[int, Any]]):
454462 This is useful when you need to know which task completed first, not just its result.
455463 """
456464
457- __slots__ = ("_tasks_like " , "_awaitables " )
465+ __slots__ = ("_originals " , "_underlying " )
458466
459467 def __init__ (self , tasks_like : Iterable [Union [AwaitableBase [Any ], task .Task [Any ]]]):
460468 """
@@ -464,41 +472,37 @@ def __init__(self, tasks_like: Iterable[Union[AwaitableBase[Any], task.Task[Any]
464472 tasks_like: Iterable of awaitables or tasks to wait for
465473 """
466474 super ().__init__ ()
467- self ._tasks_like = list (tasks_like )
468- self ._awaitables = self ._tasks_like # Alias for compatibility
475+ self ._originals = list (tasks_like )
476+ # Defer conversion to avoid issues with incomplete mocks and coroutine reuse
477+ self ._underlying : Optional [List [task .Task [Any ]]] = None
478+
479+ def _ensure_underlying (self ) -> List [task .Task [Any ]]:
480+ """Lazily convert originals to tasks, caching the result."""
481+ if self ._underlying is None :
482+ self ._underlying = []
483+ for a in self ._originals :
484+ if isinstance (a , AwaitableBase ):
485+ self ._underlying .append (a ._to_task ())
486+ elif isinstance (a , task .Task ):
487+ self ._underlying .append (a )
488+ else :
489+ raise TypeError (
490+ "when_any_with_result expects AwaitableBase or durabletask.task.Task"
491+ )
492+ return self ._underlying
469493
470494 def _to_task (self ) -> task .Task [Any ]:
471495 """Convert to a when_any task with result tracking."""
472- underlying : List [task .Task [Any ]] = []
473- for a in self ._tasks_like :
474- if isinstance (a , AwaitableBase ):
475- underlying .append (a ._to_task ())
476- elif isinstance (a , task .Task ):
477- underlying .append (a )
478- else :
479- raise TypeError (
480- "when_any_with_result expects AwaitableBase or durabletask.task.Task"
481- )
482-
483- # Use when_any and then determine which task completed
484- when_any_task = task .when_any (underlying )
485- return cast (task .Task [Any ], when_any_task )
496+ return cast (task .Task [Any ], task .when_any (self ._ensure_underlying ()))
486497
487498 def __await__ (self ) -> Generator [Any , Any , tuple [int , Any ]]:
488499 """Override to provide index + result tuple."""
489- t = self ._to_task ()
490- completed_task = yield t
491-
492- # Find which task completed by comparing results
493- underlying_tasks : List [task .Task [Any ]] = []
494- for a in self ._tasks_like :
495- if isinstance (a , AwaitableBase ):
496- underlying_tasks .append (a ._to_task ())
497- elif isinstance (a , task .Task ):
498- underlying_tasks .append (a )
500+ underlying = self ._ensure_underlying ()
501+ when_any_task = task .when_any (underlying )
502+ completed_task = yield when_any_task
499503
500504 # The completed_task should match one of our underlying tasks
501- for i , underlying_task in enumerate (underlying_tasks ):
505+ for i , underlying_task in enumerate (underlying ):
502506 if underlying_task == completed_task :
503507 return (i , completed_task .result if hasattr (completed_task , "result" ) else None )
504508
0 commit comments