diff --git a/docs/SERVICE_RESTARTS.md b/docs/SERVICE_RESTARTS.md index b2def81..07552e7 100644 --- a/docs/SERVICE_RESTARTS.md +++ b/docs/SERVICE_RESTARTS.md @@ -51,6 +51,11 @@ restart safe: epoch index are persisted to Redis on every transition, along with per-workflow epoch ordering and frontier. On startup the server rebuilds the full task DAG, ready queue, and epoch frontiers from these records (`TaskRuntime.rehydrate`). + A transition's task records, workflow status-set membership, and schedule + snapshot are written as a single atomic Redis transaction + (`WorkflowRegistry.commit_transition`), so a crash mid-persist commits the whole + transition or none of it. Event-driven transitions are additionally healed by replay; + the API-driven workflow cancel relies on this atomicity alone. - **Replayable task events.** Task lifecycle events flow through a durable Redis stream consumed from a persisted cursor. The ordering is what makes replay safe: a transition is written to durable scheduler state *before* its event is diff --git a/src/server/registries/workflow.py b/src/server/registries/workflow.py index 9477bf6..200b613 100644 --- a/src/server/registries/workflow.py +++ b/src/server/registries/workflow.py @@ -259,106 +259,51 @@ async def get_workflow_async(self, workflow_id: str) -> Workflow | None: remaining_tasks, ) - def update_workflow(self, workflow_id: str, **kwargs: Any) -> None: - mapping = _workflow_update(kwargs) - self._rds.sync.hash_set(workflow_key(workflow_id), mapping=mapping) - - async def update_workflow_async(self, workflow_id: str, **kwargs: Any) -> None: - mapping = _workflow_update(kwargs) - await self._rds.asyncio.hash_set(workflow_key(workflow_id), mapping=mapping) - - def mark_task_dispatched(self, workflow_id: str, *task_ids: str) -> None: - mapping = _workflow_update() - with self._rds.sync.control_pipeline() as pipe: - pipe.sadd(workflow_dispatched_tasks_key(workflow_id), *task_ids) - pipe.hset(workflow_key(workflow_id), mapping=mapping) - pipe.execute() - - async def mark_task_dispatched_async( - self, workflow_id: str, *task_ids: str + def commit_transition( + self, + workflow_id: str, + *, + records: Sequence[PersistedTask] = (), + dispatched: Sequence[str] = (), + pending: Sequence[str] = (), + done: Sequence[str] = (), + failed: Sequence[str] = (), + cancelled: Sequence[str] = (), + sched: WorkflowSched | None = None, ) -> None: - mapping = _workflow_update() - async with self._rds.asyncio.control_pipeline() as pipe: - pipe.sadd(workflow_dispatched_tasks_key(workflow_id), *task_ids) - pipe.hset(workflow_key(workflow_id), mapping=mapping) - await pipe.execute() - - def mark_task_done(self, workflow_id: str, *task_ids: str) -> None: - mapping = _workflow_update() - with self._rds.sync.control_pipeline() as pipe: - pipe.srem(workflow_tasks_key(workflow_id), *task_ids) - pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids) - pipe.hset(workflow_key(workflow_id), mapping=mapping) - pipe.execute() - - async def mark_task_done_async(self, workflow_id: str, *task_ids: str) -> None: - mapping = _workflow_update() - async with self._rds.asyncio.control_pipeline() as pipe: - pipe.srem(workflow_tasks_key(workflow_id), *task_ids) - pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids) - pipe.hset(workflow_key(workflow_id), mapping=mapping) - await pipe.execute() - - def mark_task_pending(self, workflow_id: str, *task_ids: str) -> None: - mapping = _workflow_update() + """Apply a workflow state delta as one atomic control-Redis transaction. + + ``records`` are upserted; ``dispatched`` / ``pending`` / ``done`` / + ``failed`` / ``cancelled`` move their task ids into the matching status-set + membership; ``sched`` snapshots the schedule when present. The records, + membership moves, the workflow's ``updated_at``, and the schedule snapshot + commit together or not at all, so a crash mid-persist can never leave + durable state half-applied. + """ + terminal = (*done, *failed, *cancelled) + touched_membership = bool(dispatched or pending or terminal) with self._rds.sync.control_pipeline() as pipe: - pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids) - pipe.hset(workflow_key(workflow_id), mapping=mapping) - pipe.execute() - - async def mark_task_pending_async(self, workflow_id: str, *task_ids: str) -> None: - mapping = _workflow_update() - async with self._rds.asyncio.control_pipeline() as pipe: - pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids) - pipe.hset(workflow_key(workflow_id), mapping=mapping) - await pipe.execute() - - def mark_task_failed(self, workflow_id: str, *task_ids: str) -> None: - mapping = _workflow_update() - with self._rds.sync.control_pipeline() as pipe: - pipe.srem(workflow_tasks_key(workflow_id), *task_ids) - pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids) - pipe.sadd(workflow_failed_tasks_key(workflow_id), *task_ids) - pipe.hset(workflow_key(workflow_id), mapping=mapping) - pipe.execute() - - async def mark_task_failed_async(self, workflow_id: str, *task_ids: str) -> None: - mapping = _workflow_update() - async with self._rds.asyncio.control_pipeline() as pipe: - pipe.srem(workflow_tasks_key(workflow_id), *task_ids) - pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids) - pipe.sadd(workflow_failed_tasks_key(workflow_id), *task_ids) - pipe.hset(workflow_key(workflow_id), mapping=mapping) - await pipe.execute() - - def mark_task_cancelled(self, workflow_id: str, *task_ids: str) -> None: - mapping = _workflow_update() - with self._rds.sync.control_pipeline() as pipe: - pipe.srem(workflow_tasks_key(workflow_id), *task_ids) - pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids) - pipe.sadd(workflow_cancelled_tasks_key(workflow_id), *task_ids) - pipe.hset(workflow_key(workflow_id), mapping=mapping) + for item in records: + pipe.set(task_state_key(item.record.task_id), item.model_dump_json()) + if dispatched: + pipe.sadd(workflow_dispatched_tasks_key(workflow_id), *dispatched) + if pending: + pipe.srem(workflow_dispatched_tasks_key(workflow_id), *pending) + if terminal: + pipe.srem(workflow_tasks_key(workflow_id), *terminal) + pipe.srem(workflow_dispatched_tasks_key(workflow_id), *terminal) + if failed: + pipe.sadd(workflow_failed_tasks_key(workflow_id), *failed) + if cancelled: + pipe.sadd(workflow_cancelled_tasks_key(workflow_id), *cancelled) + if touched_membership or sched is not None: + pipe.hset(workflow_key(workflow_id), mapping=_workflow_update()) + if sched is not None: + pipe.set(workflow_sched_key(workflow_id), sched.model_dump_json()) pipe.execute() - async def mark_task_cancelled_async(self, workflow_id: str, *task_ids: str) -> None: - mapping = _workflow_update() - async with self._rds.asyncio.control_pipeline() as pipe: - pipe.srem(workflow_tasks_key(workflow_id), *task_ids) - pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids) - pipe.sadd(workflow_cancelled_tasks_key(workflow_id), *task_ids) - pipe.hset(workflow_key(workflow_id), mapping=mapping) - await pipe.execute() - # ---- Durable task state (for restart rehydration) ----------------- # - def save_task_states(self, items: Sequence[PersistedTask]) -> None: - if not items: - return - with self._rds.sync.control_pipeline() as pipe: - for item in items: - pipe.set(task_state_key(item.record.task_id), item.model_dump_json()) - pipe.execute() - async def save_task_states_async(self, items: Sequence[PersistedTask]) -> None: if not items: return @@ -387,30 +332,6 @@ async def load_task_states_async( PersistedTask.model_validate_json(blob) if blob else None for blob in blobs ] - def delete_task_states(self, *task_ids: str) -> None: - if not task_ids: - return - with self._rds.sync.control_pipeline() as pipe: - for task_id in task_ids: - pipe.delete(task_state_key(task_id)) - pipe.execute() - - async def delete_task_states_async(self, *task_ids: str) -> None: - if not task_ids: - return - async with self._rds.asyncio.control_pipeline() as pipe: - for task_id in task_ids: - pipe.delete(task_state_key(task_id)) - await pipe.execute() - - def save_workflow_sched( - self, workflow_id: str, in_epoch_order: bool, epoch_frontier: int - ) -> None: - payload = WorkflowSched( - in_epoch_order=in_epoch_order, epoch_frontier=epoch_frontier - ).model_dump_json() - self._rds.sync.set_value(workflow_sched_key(workflow_id), payload) - async def save_workflow_sched_async( self, workflow_id: str, in_epoch_order: bool, epoch_frontier: int ) -> None: @@ -419,10 +340,6 @@ async def save_workflow_sched_async( ).model_dump_json() await self._rds.asyncio.set_value(workflow_sched_key(workflow_id), payload) - def load_workflow_sched(self, workflow_id: str) -> WorkflowSched | None: - blob = self._rds.sync.get(workflow_sched_key(workflow_id)) - return WorkflowSched.model_validate_json(blob) if blob else None - async def load_workflow_sched_async(self, workflow_id: str) -> WorkflowSched | None: blob = await self._rds.asyncio.get(workflow_sched_key(workflow_id)) return WorkflowSched.model_validate_json(blob) if blob else None diff --git a/src/server/task/runtime.py b/src/server/task/runtime.py index 209aec5..1547c8c 100644 --- a/src/server/task/runtime.py +++ b/src/server/task/runtime.py @@ -371,62 +371,80 @@ def _persisted_task_locked(self, task_id: str) -> PersistedTask | None: epoch_index=self._task_epoch_index.get(task_id), ) - def _persist_locked(self, *task_ids: str) -> None: - items = [ + def _records_locked(self, *task_ids: str) -> list[PersistedTask]: + return [ persisted for task_id in dict.fromkeys(task_ids) if (persisted := self._persisted_task_locked(task_id)) ] - if items: - self._workflow_registry.save_task_states(items) - - def _persist_sched_locked(self, workflow_id: str) -> None: - self._workflow_registry.save_workflow_sched( - workflow_id, - self._workflow_in_epoch_order.get(workflow_id, False), - self._workflow_epoch_frontier.get(workflow_id, 0), + + def _sched_locked(self, workflow_id: str) -> WorkflowSched: + return WorkflowSched( + in_epoch_order=self._workflow_in_epoch_order.get(workflow_id, False), + epoch_frontier=self._workflow_epoch_frontier.get(workflow_id, 0), ) - def _persist_terminal_locked(self, *task_ids: str) -> None: - """Persist each task's final state — its record and its done/failed/cancelled - set membership (by current status) — as the single last step of a transition. + def _persist_locked(self, *task_ids: str) -> None: + """Commit task records (no membership change) atomically, per workflow.""" + by_workflow: dict[str, list[str]] = defaultdict(list) + for task_id in dict.fromkeys(task_ids): + if record := self._tasks.get(task_id): + by_workflow[record.workflow_id].append(task_id) + for workflow_id, ids in by_workflow.items(): + self._workflow_registry.commit_transition( + workflow_id, records=self._records_locked(*ids) + ) - Persisting only after all in-memory mutations means a failed write can't leave - the in-memory state half-applied; the error propagates and the at-least-once - replay re-persists via ``_repersist_terminal_workflow_locked``. This assumes - the in-memory mutations never raise, which holds while ordered tasks carry - ``position_in_epoch`` (so the ready-queue helpers never hit their guards). + def _persist_terminal_locked(self, *task_ids: str, sched: bool = True) -> None: + """Commit each task's final state — its record and its done/failed/cancelled + set membership (by current status) — and the workflow schedule, as one atomic + transaction per workflow and the single last step of a transition. + + Committing only after all in-memory mutations means a failed or crashed write + can't leave durable state half-applied: the transaction commits in full or not + at all. Event-driven callers additionally heal via the at-least-once replay + (``_repersist_terminal_workflow_locked``); the API-driven cancel relies on this + atomicity alone. Assumes the in-memory mutations never raise, which holds while + ordered tasks carry ``position_in_epoch`` (so the ready-queue helpers never hit + their guards). """ - workflow_terminal_tasks: dict[str, tuple[list[str], list[str], list[str]]] = ( - defaultdict(lambda: ([], [], [])) + moves: dict[str, tuple[list[str], list[str], list[str]]] = defaultdict( + lambda: ([], [], []) ) for task_id in dict.fromkeys(task_ids): record = self._tasks.get(task_id) if record is None: continue - workflow_id = record.workflow_id match record.status: case TaskStatus.DONE: - workflow_terminal_tasks[workflow_id][0].append(task_id) + moves[record.workflow_id][0].append(task_id) case TaskStatus.FAILED: - workflow_terminal_tasks[workflow_id][1].append(task_id) + moves[record.workflow_id][1].append(task_id) case TaskStatus.CANCELLED: - workflow_terminal_tasks[workflow_id][2].append(task_id) - for workflow_id, (done, failed, cancelled) in workflow_terminal_tasks.items(): - if done: - self._workflow_registry.mark_task_done(workflow_id, *done) - if failed: - self._workflow_registry.mark_task_failed(workflow_id, *failed) - if cancelled: - self._workflow_registry.mark_task_cancelled(workflow_id, *cancelled) - self._persist_locked(*task_ids) + moves[record.workflow_id][2].append(task_id) + case _: + self._logger.warning( + "Non-terminal task %s (%s) skipped in terminal persist", + task_id, + record.status, + ) + for workflow_id, (done, failed, cancelled) in moves.items(): + if ids := done + failed + cancelled: + self._workflow_registry.commit_transition( + workflow_id, + records=self._records_locked(*ids), + done=done, + failed=failed, + cancelled=cancelled, + sched=self._sched_locked(workflow_id) if sched else None, + ) def _repersist_terminal_workflow_locked(self, workflow_id: str) -> None: - """Re-persist the workflow's already-terminal tasks and schedule state. + """Re-commit the workflow's already-terminal tasks and schedule state. The idempotency guard calls this on a replayed terminal event: the original transition may have failed its persist after committing in memory, so - re-persisting makes the durable state current before the consumer's cursor + re-committing makes the durable state current before the consumer's cursor advances past the event (else the task re-runs after a restart). It covers the whole workflow, not just the replayed task, because a cascade's other affected tasks aren't identifiable here. Idempotent; only on a rare duplicate replay. @@ -439,7 +457,10 @@ def _repersist_terminal_workflow_locked(self, workflow_id: str) -> None: ] if terminal_ids: self._persist_terminal_locked(*terminal_ids) - self._persist_sched_locked(workflow_id) + else: + self._workflow_registry.commit_transition( + workflow_id, sched=self._sched_locked(workflow_id) + ) # ------------------------------------------------------------------ # # Ready queue helpers @@ -658,7 +679,6 @@ def mark_pending(self, task_id: str, *, increment_retry: bool = False) -> None: record.started_ts = None record.finished_ts = None record.error = None - self._workflow_registry.mark_task_pending(record.workflow_id, task_id) if increment_retry: try: if record.max_attempts is not None and record.max_attempts >= 0: @@ -667,7 +687,11 @@ def mark_pending(self, task_id: str, *, increment_retry: bool = False) -> None: record.attempts = record.attempts + 1 except Exception: record.attempts = (record.attempts or 0) + 1 - self._persist_locked(task_id) + self._workflow_registry.commit_transition( + record.workflow_id, + records=self._records_locked(task_id), + pending=[task_id], + ) def requeue(self, task_id: str, *, front: bool = False) -> bool: """Reinsert a task into the ready queue.""" @@ -738,8 +762,11 @@ def _plan_merge_locked( sibling_record.merged_parent_id = task_id sibling_record.assigned_worker = None sibling_record.merge_slice = None - self._workflow_registry.mark_task_dispatched(record.workflow_id, *siblings) - self._persist_locked(task_id, *siblings) + self._workflow_registry.commit_transition( + record.workflow_id, + records=self._records_locked(task_id, *siblings), + dispatched=siblings, + ) return siblings @@ -891,10 +918,13 @@ def mark_dispatched(self, task_id: str, worker: Worker) -> None: record.dispatched_ts = time.time() record.next_retry_at = None record.supplier_id = supplier_id - self._workflow_registry.mark_task_dispatched(record.workflow_id, task_id) self._remove_from_ready_locked(task_id) self._merge_bucket_remove(task_id) - self._persist_locked(task_id) + self._workflow_registry.commit_transition( + record.workflow_id, + records=self._records_locked(task_id), + dispatched=[task_id], + ) def mark_started( self, @@ -915,8 +945,11 @@ def mark_started( record.started_ts = started_ts if worker_id: record.assigned_worker = worker_id - self._workflow_registry.mark_task_dispatched(record.workflow_id, task_id) - self._persist_locked(task_id) + self._workflow_registry.commit_transition( + record.workflow_id, + records=self._records_locked(task_id), + dispatched=[task_id], + ) def mark_updated(self, task_id: str, payload: dict[str, Any]) -> None: with self._lock: @@ -1011,8 +1044,6 @@ def mark_succeeded( ) self._persist_terminal_locked(task_id, *merged_children_ids) - if record is not None: - self._persist_sched_locked(record.workflow_id) if ready_children: self._cv.notify_all() @@ -1121,8 +1152,6 @@ def mark_failed( self._persist_terminal_locked( task_id, *merged_children_ids, *(dep_id for dep_id, _ in impacted) ) - if record is not None: - self._persist_sched_locked(record.workflow_id) return impacted, merged_children_ids, usages @@ -1141,6 +1170,8 @@ def cancel_workflow(self, workflow_id: str, reason: str = "cancelled") -> list[s for item in self._tasks.items() if item[1].workflow_id == workflow_id ] + if not workflow_tasks: + return touched # Unknown workflow: no records to move for task_id, record in workflow_tasks: match record.status: case TaskStatus.PENDING if not self._parent_is_active(task_id): @@ -1184,9 +1215,12 @@ def cancel_workflow(self, workflow_id: str, reason: str = "cancelled") -> list[s self._workflow_in_epoch_order.pop(workflow_id, None) for task_id, _ in workflow_tasks: self._task_epoch_index.pop(task_id, None) - self._persist_locked(*cancelling) - self._persist_terminal_locked(*cancelled) - self._persist_sched_locked(workflow_id) + self._workflow_registry.commit_transition( + workflow_id, + records=self._records_locked(*touched), + cancelled=cancelled, + sched=self._sched_locked(workflow_id), + ) for interrupt in interrupts: worker = self._worker_registry.get_worker(interrupt.worker_id) @@ -1244,7 +1278,7 @@ def mark_cancelled( self._merge_key_by_task.pop(task_id, None) self._merge_children_map.pop(task_id, None) record.assigned_worker = None - self._persist_terminal_locked(task_id) + self._persist_terminal_locked(task_id, sched=False) return usages def get_record(self, task_id: str) -> TaskRecord | None: diff --git a/tests/server/dispatcher/helpers.py b/tests/server/dispatcher/helpers.py index 67b96ad..f771f40 100644 --- a/tests/server/dispatcher/helpers.py +++ b/tests/server/dispatcher/helpers.py @@ -2,12 +2,14 @@ import logging import tempfile +from collections.abc import Sequence from pathlib import Path from types import SimpleNamespace from typing import Any from unittest import mock from server.dispatcher import Dispatcher +from server.registries.workflow import PersistedTask, WorkflowSched class CapturingDispatcher(Dispatcher): @@ -31,16 +33,19 @@ class WorkflowRegistryStub: async def register_workflow_async(self, workflow_id: str, tasks: list[Any]) -> None: return None - def mark_task_dispatched(self, workflow_id: str, *task_ids: str) -> None: ... - def mark_task_done(self, workflow_id: str, *task_ids: str) -> None: ... - def mark_task_failed(self, workflow_id: str, *task_ids: str) -> None: ... - def mark_task_pending(self, workflow_id: str, *task_ids: str) -> None: ... - def mark_task_cancelled(self, workflow_id: str, *task_ids: str) -> None: ... - def save_task_states(self, items: Any) -> None: ... - async def save_task_states_async(self, items: Any) -> None: ... - def save_workflow_sched( - self, wid: str, in_epoch_order: bool, frontier: int + def commit_transition( + self, + workflow_id: str, + *, + records: Sequence[PersistedTask] = (), + dispatched: Sequence[str] = (), + pending: Sequence[str] = (), + done: Sequence[str] = (), + failed: Sequence[str] = (), + cancelled: Sequence[str] = (), + sched: WorkflowSched | None = None, ) -> None: ... + async def save_task_states_async(self, items: Any) -> None: ... async def save_workflow_sched_async( self, wid: str, in_epoch_order: bool, frontier: int ) -> None: ... diff --git a/tests/server/task/test_runtime_epoch_order.py b/tests/server/task/test_runtime_epoch_order.py index 9ffe527..9f13c69 100644 --- a/tests/server/task/test_runtime_epoch_order.py +++ b/tests/server/task/test_runtime_epoch_order.py @@ -3,8 +3,10 @@ import asyncio import logging import threading +from collections.abc import Sequence from typing import Any, cast +from server.registries.workflow import PersistedTask, WorkflowSched from server.task.models import TaskStatus from server.task.runtime import TaskRuntime @@ -13,32 +15,23 @@ class _WorkflowRegistryStub: async def register_workflow_async(self, workflow_id: str, tasks: list[Any]) -> None: return None - def mark_task_dispatched(self, workflow_id: str, *task_ids: str) -> None: - return None - - def mark_task_done(self, workflow_id: str, *task_ids: str) -> None: - return None - - def mark_task_failed(self, workflow_id: str, *task_ids: str) -> None: - return None - - def mark_task_pending(self, workflow_id: str, *task_ids: str) -> None: - return None - - def mark_task_cancelled(self, workflow_id: str, *task_ids: str) -> None: - return None - - def save_task_states(self, items: Any) -> None: + def commit_transition( + self, + workflow_id: str, + *, + records: Sequence[PersistedTask] = (), + dispatched: Sequence[str] = (), + pending: Sequence[str] = (), + done: Sequence[str] = (), + failed: Sequence[str] = (), + cancelled: Sequence[str] = (), + sched: WorkflowSched | None = None, + ) -> None: return None async def save_task_states_async(self, items: Any) -> None: return None - def save_workflow_sched( - self, workflow_id: str, in_epoch_order: bool, frontier: int - ) -> None: - return None - async def save_workflow_sched_async( self, workflow_id: str, in_epoch_order: bool, frontier: int ) -> None: diff --git a/tests/server/task/test_runtime_rehydrate.py b/tests/server/task/test_runtime_rehydrate.py index 2ffb2b2..f82e962 100644 --- a/tests/server/task/test_runtime_rehydrate.py +++ b/tests/server/task/test_runtime_rehydrate.py @@ -2,6 +2,7 @@ import logging import threading +from collections.abc import Sequence from types import SimpleNamespace from typing import Any, cast @@ -74,11 +75,22 @@ def load_workflow_sched(self, workflow_id: str) -> WorkflowSched | None: async def load_workflow_sched_async(self, workflow_id: str) -> WorkflowSched | None: return self.load_workflow_sched(workflow_id) - def mark_task_dispatched(self, workflow_id: str, *task_ids: str) -> None: ... - def mark_task_done(self, workflow_id: str, *task_ids: str) -> None: ... - def mark_task_failed(self, workflow_id: str, *task_ids: str) -> None: ... - def mark_task_pending(self, workflow_id: str, *task_ids: str) -> None: ... - def mark_task_cancelled(self, workflow_id: str, *task_ids: str) -> None: ... + def commit_transition( + self, + workflow_id: str, + *, + records: Sequence[PersistedTask] = (), + dispatched: Sequence[str] = (), + pending: Sequence[str] = (), + done: Sequence[str] = (), + failed: Sequence[str] = (), + cancelled: Sequence[str] = (), + sched: WorkflowSched | None = None, + ) -> None: + for item in records: + self.task_blobs[item.record.task_id] = item.model_dump_json() + if sched is not None: + self.sched[workflow_id] = sched.model_dump_json() class _WorkerRegistryStub: @@ -343,10 +355,10 @@ async def test_mark_succeeded_applies_in_memory_atomically_when_persist_raises( _, ids = await _register(runtime, GRAPH) a, b = ids["a"], ids["b"] - def boom(workflow_id: str, *task_ids: str) -> None: + def boom(*args: Any, **kwargs: Any) -> None: raise RuntimeError("redis down") - monkeypatch.setattr(registry, "mark_task_done", boom) + monkeypatch.setattr(registry, "commit_transition", boom) with pytest.raises(RuntimeError): runtime.mark_succeeded(a, "wkr-1", {}, "2026-06-01T00:00:00Z") @@ -357,7 +369,7 @@ def boom(workflow_id: str, *task_ids: str) -> None: assert b in runtime._ready_index # The at-least-once replay re-runs and is a no-op via the idempotency guard. - monkeypatch.setattr(registry, "mark_task_done", lambda *args, **kwargs: None) + monkeypatch.setattr(registry, "commit_transition", lambda *args, **kwargs: None) assert runtime.mark_succeeded(a, "wkr-1", {}, "2026-06-01T00:00:00Z") == [] @@ -370,10 +382,10 @@ async def test_mark_failed_applies_cascade_atomically_when_persist_raises( _, ids = await _register(runtime, GRAPH) a, b = ids["a"], ids["b"] - def boom(workflow_id: str, *task_ids: str) -> None: + def boom(*args: Any, **kwargs: Any) -> None: raise RuntimeError("redis down") - monkeypatch.setattr(registry, "mark_task_failed", boom) + monkeypatch.setattr(registry, "commit_transition", boom) with pytest.raises(RuntimeError): runtime.mark_failed(a, "wkr-1", {}, "2026-06-01T00:00:00Z") @@ -385,7 +397,7 @@ def boom(workflow_id: str, *task_ids: str) -> None: assert record_b is not None and record_b.status == TaskStatus.FAILED # Replay is a no-op via the idempotency guard (task already terminal). - monkeypatch.setattr(registry, "mark_task_failed", lambda *args, **kwargs: None) + monkeypatch.setattr(registry, "commit_transition", lambda *args, **kwargs: None) impacted, _, _ = runtime.mark_failed(a, "wkr-1", {}, "2026-06-01T00:00:00Z") assert impacted == [] @@ -399,16 +411,16 @@ async def test_replayed_terminal_event_repersists_after_failed_write( _, ids = await _register(runtime, GRAPH) a, b = ids["a"], ids["b"] - real_save = registry.save_task_states + real_commit = registry.commit_transition calls = {"n": 0} - def flaky_save(items: Any) -> None: + def flaky_commit(*args: Any, **kwargs: Any) -> None: calls["n"] += 1 if calls["n"] == 1: raise RuntimeError("redis down") - real_save(items) + real_commit(*args, **kwargs) - monkeypatch.setattr(registry, "save_task_states", flaky_save) + monkeypatch.setattr(registry, "commit_transition", flaky_commit) def persisted_status(task_id: str) -> str: return PersistedTask.model_validate_json( @@ -442,10 +454,10 @@ async def test_mark_cancelled_applies_in_memory_atomically_when_persist_raises( _, ids = await _register(runtime, GRAPH) a = ids["a"] - def boom(workflow_id: str, *task_ids: str) -> None: + def boom(*args: Any, **kwargs: Any) -> None: raise RuntimeError("redis down") - monkeypatch.setattr(registry, "mark_task_cancelled", boom) + monkeypatch.setattr(registry, "commit_transition", boom) with pytest.raises(RuntimeError): runtime.mark_cancelled(a, "wkr-1", {}, "2026-06-01T00:00:00Z") @@ -454,7 +466,7 @@ def boom(workflow_id: str, *task_ids: str) -> None: assert record_a is not None and record_a.status == TaskStatus.CANCELLED # Replay is a no-op via the idempotency guard (task already cancelled). - monkeypatch.setattr(registry, "mark_task_cancelled", lambda *args, **kwargs: None) + monkeypatch.setattr(registry, "commit_transition", lambda *args, **kwargs: None) runtime.mark_cancelled(a, "wkr-1", {}, "2026-06-01T00:00:00Z") record_a = runtime.get_record(a) assert record_a is not None and record_a.status == TaskStatus.CANCELLED @@ -469,16 +481,16 @@ async def test_mark_cancelled_repersists_on_replay_after_failed_write( _, ids = await _register(runtime, GRAPH) a = ids["a"] - real_save = registry.save_task_states + real_commit = registry.commit_transition calls = {"n": 0} - def flaky_save(items: Any) -> None: + def flaky_commit(*args: Any, **kwargs: Any) -> None: calls["n"] += 1 if calls["n"] == 1: raise RuntimeError("redis down") - real_save(items) + real_commit(*args, **kwargs) - monkeypatch.setattr(registry, "save_task_states", flaky_save) + monkeypatch.setattr(registry, "commit_transition", flaky_commit) def persisted_status(task_id: str) -> str: return PersistedTask.model_validate_json( @@ -493,3 +505,63 @@ def persisted_status(task_id: str) -> str: # Replay of the same cancellation: the guard heals by re-persisting. runtime.mark_cancelled(a, "wkr-1", {}, "2026-06-01T00:00:00Z") assert persisted_status(a) == TaskStatus.CANCELLED + + +@pytest.mark.anyio +async def test_cancel_workflow_commits_atomically_on_crash( + monkeypatch: pytest.MonkeyPatch, +) -> None: + registry = FakeWorkflowRegistry() + runtime = _runtime(registry) + workflow_id, ids = await _register(runtime, GRAPH) + a, b = ids["a"], ids["b"] + + def boom(*args: Any, **kwargs: Any) -> None: + raise RuntimeError("redis down") + + def persisted_status(task_id: str) -> str: + return PersistedTask.model_validate_json( + registry.task_blobs[task_id] + ).record.status + + monkeypatch.setattr(registry, "commit_transition", boom) + with pytest.raises(RuntimeError): + runtime.cancel_workflow(workflow_id) + + # Persist is the single last step, so the cancellation is fully applied in + # memory even though the commit failed. + record_a = runtime.get_record(a) + assert record_a is not None and record_a.status == TaskStatus.CANCELLED + + # The cancel has no event-replay backstop, but the atomic commit never ran, + # so durable state is untouched — nothing is half-cancelled. A fresh restart + # restores the pre-cancel workflow, which the operator can cancel again. + assert persisted_status(a) == TaskStatus.PENDING + assert persisted_status(b) == TaskStatus.PENDING + + restored = _runtime(registry) + assert await restored.rehydrate() == 1 + restored_a = restored.get_record(a) + restored_b = restored.get_record(b) + assert restored_a is not None and restored_a.status == TaskStatus.PENDING + assert restored_b is not None and restored_b.status == TaskStatus.PENDING + + +@pytest.mark.anyio +async def test_rehydrate_restores_cancelled_workflow() -> None: + registry = FakeWorkflowRegistry() + runtime = _runtime(registry) + workflow_id, ids = await _register(runtime, GRAPH) + a, b = ids["a"], ids["b"] + + runtime.cancel_workflow(workflow_id) + + restored = _runtime(registry) + await restored.rehydrate() + + # Cancelled tasks rehydrate terminal and are never re-enqueued. + record_a = restored.get_record(a) + record_b = restored.get_record(b) + assert record_a is not None and record_a.status == TaskStatus.CANCELLED + assert record_b is not None and record_b.status == TaskStatus.CANCELLED + assert restored.ready_queue_length() == 0