Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/SERVICE_RESTARTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ restart safe:
epoch index are persisted to Redis on every transition, along with per-workflow
epoch ordering and frontier. On startup the server rebuilds the full task DAG,
ready queue, and epoch frontiers from these records (`TaskRuntime.rehydrate`).
A transition's task records, workflow status-set membership, and schedule
snapshot are written as a single atomic Redis transaction
(`WorkflowRegistry.commit_transition`), so a crash mid-persist commits the whole
transition or none of it. Event-driven transitions are additionally healed by replay;
the API-driven workflow cancel relies on this atomicity alone.
- **Replayable task events.** Task lifecycle events flow through a durable Redis
stream consumed from a persisted cursor. The ordering is what makes replay
safe: a transition is written to durable scheduler state *before* its event is
Expand Down
161 changes: 39 additions & 122 deletions src/server/registries/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,106 +259,51 @@ async def get_workflow_async(self, workflow_id: str) -> Workflow | None:
remaining_tasks,
)

def update_workflow(self, workflow_id: str, **kwargs: Any) -> None:
mapping = _workflow_update(kwargs)
self._rds.sync.hash_set(workflow_key(workflow_id), mapping=mapping)

async def update_workflow_async(self, workflow_id: str, **kwargs: Any) -> None:
mapping = _workflow_update(kwargs)
await self._rds.asyncio.hash_set(workflow_key(workflow_id), mapping=mapping)

def mark_task_dispatched(self, workflow_id: str, *task_ids: str) -> None:
mapping = _workflow_update()
with self._rds.sync.control_pipeline() as pipe:
pipe.sadd(workflow_dispatched_tasks_key(workflow_id), *task_ids)
pipe.hset(workflow_key(workflow_id), mapping=mapping)
pipe.execute()

async def mark_task_dispatched_async(
self, workflow_id: str, *task_ids: str
def commit_transition(
self,
workflow_id: str,
*,
records: Sequence[PersistedTask] = (),
dispatched: Sequence[str] = (),
pending: Sequence[str] = (),
done: Sequence[str] = (),
failed: Sequence[str] = (),
cancelled: Sequence[str] = (),
sched: WorkflowSched | None = None,
) -> None:
mapping = _workflow_update()
async with self._rds.asyncio.control_pipeline() as pipe:
pipe.sadd(workflow_dispatched_tasks_key(workflow_id), *task_ids)
pipe.hset(workflow_key(workflow_id), mapping=mapping)
await pipe.execute()

def mark_task_done(self, workflow_id: str, *task_ids: str) -> None:
mapping = _workflow_update()
with self._rds.sync.control_pipeline() as pipe:
pipe.srem(workflow_tasks_key(workflow_id), *task_ids)
pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids)
pipe.hset(workflow_key(workflow_id), mapping=mapping)
pipe.execute()

async def mark_task_done_async(self, workflow_id: str, *task_ids: str) -> None:
mapping = _workflow_update()
async with self._rds.asyncio.control_pipeline() as pipe:
pipe.srem(workflow_tasks_key(workflow_id), *task_ids)
pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids)
pipe.hset(workflow_key(workflow_id), mapping=mapping)
await pipe.execute()

def mark_task_pending(self, workflow_id: str, *task_ids: str) -> None:
mapping = _workflow_update()
"""Apply a workflow state delta as one atomic control-Redis transaction.

``records`` are upserted; ``dispatched`` / ``pending`` / ``done`` /
``failed`` / ``cancelled`` move their task ids into the matching status-set
membership; ``sched`` snapshots the schedule when present. The records,
membership moves, the workflow's ``updated_at``, and the schedule snapshot
commit together or not at all, so a crash mid-persist can never leave
durable state half-applied.
"""
terminal = (*done, *failed, *cancelled)
touched_membership = bool(dispatched or pending or terminal)
with self._rds.sync.control_pipeline() as pipe:
pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids)
pipe.hset(workflow_key(workflow_id), mapping=mapping)
pipe.execute()

async def mark_task_pending_async(self, workflow_id: str, *task_ids: str) -> None:
mapping = _workflow_update()
async with self._rds.asyncio.control_pipeline() as pipe:
pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids)
pipe.hset(workflow_key(workflow_id), mapping=mapping)
await pipe.execute()

def mark_task_failed(self, workflow_id: str, *task_ids: str) -> None:
mapping = _workflow_update()
with self._rds.sync.control_pipeline() as pipe:
pipe.srem(workflow_tasks_key(workflow_id), *task_ids)
pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids)
pipe.sadd(workflow_failed_tasks_key(workflow_id), *task_ids)
pipe.hset(workflow_key(workflow_id), mapping=mapping)
pipe.execute()

async def mark_task_failed_async(self, workflow_id: str, *task_ids: str) -> None:
mapping = _workflow_update()
async with self._rds.asyncio.control_pipeline() as pipe:
pipe.srem(workflow_tasks_key(workflow_id), *task_ids)
pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids)
pipe.sadd(workflow_failed_tasks_key(workflow_id), *task_ids)
pipe.hset(workflow_key(workflow_id), mapping=mapping)
await pipe.execute()

def mark_task_cancelled(self, workflow_id: str, *task_ids: str) -> None:
mapping = _workflow_update()
with self._rds.sync.control_pipeline() as pipe:
pipe.srem(workflow_tasks_key(workflow_id), *task_ids)
pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids)
pipe.sadd(workflow_cancelled_tasks_key(workflow_id), *task_ids)
pipe.hset(workflow_key(workflow_id), mapping=mapping)
for item in records:
pipe.set(task_state_key(item.record.task_id), item.model_dump_json())
if dispatched:
pipe.sadd(workflow_dispatched_tasks_key(workflow_id), *dispatched)
if pending:
pipe.srem(workflow_dispatched_tasks_key(workflow_id), *pending)
if terminal:
pipe.srem(workflow_tasks_key(workflow_id), *terminal)
pipe.srem(workflow_dispatched_tasks_key(workflow_id), *terminal)
if failed:
pipe.sadd(workflow_failed_tasks_key(workflow_id), *failed)
if cancelled:
pipe.sadd(workflow_cancelled_tasks_key(workflow_id), *cancelled)
if touched_membership or sched is not None:
pipe.hset(workflow_key(workflow_id), mapping=_workflow_update())
if sched is not None:
pipe.set(workflow_sched_key(workflow_id), sched.model_dump_json())
pipe.execute()

async def mark_task_cancelled_async(self, workflow_id: str, *task_ids: str) -> None:
mapping = _workflow_update()
async with self._rds.asyncio.control_pipeline() as pipe:
pipe.srem(workflow_tasks_key(workflow_id), *task_ids)
pipe.srem(workflow_dispatched_tasks_key(workflow_id), *task_ids)
pipe.sadd(workflow_cancelled_tasks_key(workflow_id), *task_ids)
pipe.hset(workflow_key(workflow_id), mapping=mapping)
await pipe.execute()

# ---- Durable task state (for restart rehydration) ----------------- #

def save_task_states(self, items: Sequence[PersistedTask]) -> None:
if not items:
return
with self._rds.sync.control_pipeline() as pipe:
for item in items:
pipe.set(task_state_key(item.record.task_id), item.model_dump_json())
pipe.execute()

async def save_task_states_async(self, items: Sequence[PersistedTask]) -> None:
if not items:
return
Expand Down Expand Up @@ -387,30 +332,6 @@ async def load_task_states_async(
PersistedTask.model_validate_json(blob) if blob else None for blob in blobs
]

def delete_task_states(self, *task_ids: str) -> None:
if not task_ids:
return
with self._rds.sync.control_pipeline() as pipe:
for task_id in task_ids:
pipe.delete(task_state_key(task_id))
pipe.execute()

async def delete_task_states_async(self, *task_ids: str) -> None:
if not task_ids:
return
async with self._rds.asyncio.control_pipeline() as pipe:
for task_id in task_ids:
pipe.delete(task_state_key(task_id))
await pipe.execute()

def save_workflow_sched(
self, workflow_id: str, in_epoch_order: bool, epoch_frontier: int
) -> None:
payload = WorkflowSched(
in_epoch_order=in_epoch_order, epoch_frontier=epoch_frontier
).model_dump_json()
self._rds.sync.set_value(workflow_sched_key(workflow_id), payload)

async def save_workflow_sched_async(
self, workflow_id: str, in_epoch_order: bool, epoch_frontier: int
) -> None:
Expand All @@ -419,10 +340,6 @@ async def save_workflow_sched_async(
).model_dump_json()
await self._rds.asyncio.set_value(workflow_sched_key(workflow_id), payload)

def load_workflow_sched(self, workflow_id: str) -> WorkflowSched | None:
blob = self._rds.sync.get(workflow_sched_key(workflow_id))
return WorkflowSched.model_validate_json(blob) if blob else None

async def load_workflow_sched_async(self, workflow_id: str) -> WorkflowSched | None:
blob = await self._rds.asyncio.get(workflow_sched_key(workflow_id))
return WorkflowSched.model_validate_json(blob) if blob else None
Expand Down
Loading
Loading