3 changes: 2 additions & 1 deletion src/workflows/server/abstract_workflow_store.py
@@ -2,7 +2,7 @@
from abc import abstractmethod, ABC
from typing import Literal, Optional, List, Any
from dataclasses import dataclass
from pydantic import BaseModel
from pydantic import BaseModel, Field


Status = Literal["running", "completed", "failed", "cancelled"]
@@ -23,6 +23,7 @@ class PersistentHandler(BaseModel):
workflow_name: str
status: Status
ctx: dict[str, Any]
handler_metadata_json: str | None = Field(default=None)

Contributor:
Why the nested JSON? Can't this just be a dict[str, Any] (and default to {})?

Collaborator Author:
Mostly ease of storage, lol, but yes, good point.
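For illustration, a minimal sketch of the suggestion above: keep the metadata as a plain dict with an empty-dict default instead of a nested JSON string. This is the reviewer's idea, not what the diff currently does; the field name mirrors the handler_metadata naming used elsewhere in the PR rather than any final API.

```python
from typing import Any, Literal

from pydantic import BaseModel, Field

Status = Literal["running", "completed", "failed", "cancelled"]


class PersistentHandler(BaseModel):
    handler_id: str
    workflow_name: str
    status: Status
    ctx: dict[str, Any]
    # Hypothetical alternative to handler_metadata_json: store metadata as a
    # plain dict and let each concrete store decide how to serialize it.
    handler_metadata: dict[str, Any] = Field(default_factory=dict)
```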



class AbstractWorkflowStore(ABC):
185 changes: 171 additions & 14 deletions src/workflows/server/server.py
@@ -61,6 +61,7 @@ class HandlerDict(TypedDict):
started_at: str
updated_at: str | None
completed_at: str | None
handler_metadata: dict[str, Any] | None


class WorkflowServer:
@@ -73,7 +74,6 @@ def __init__(
persistence_backoff: list[float] = [0.5, 3],
):
self._workflows: dict[str, Workflow] = {}
self._contexts: dict[str, Context] = {}
self._handlers: dict[str, _WorkflowHandler] = {}
self._results: dict[str, StopEvent] = {}
self._workflow_store = workflow_store
@@ -197,7 +197,7 @@ async def start(self) -> "WorkflowServer":
continue

self._run_workflow_handler(
persistent.handler_id, persistent.workflow_name, handler
persistent.handler_id, persistent.workflow_name, handler, None
)
return self

@@ -377,6 +377,9 @@ async def _run_workflow(self, request: Request) -> JSONResponse:
handler_id:
type: string
description: Workflow handler identifier to continue from a previous completed run.
handler_metadata:
type: object
description: Optional metadata to attach to the handler. If handler_id is provided, this metadata will overwrite existing metadata.
kwargs:
type: object
description: Additional keyword arguments for the workflow.
@@ -395,9 +398,12 @@ async def _run_workflow(self, request: Request) -> JSONResponse:
description: Error running workflow or invalid request body
"""
workflow = self._extract_workflow(request)
context, start_event, handler_id = await self._extract_run_params(
request, workflow.workflow, workflow.name
)
(
context,
start_event,
handler_id,
handler_metadata,
) = await self._extract_run_params(request, workflow.workflow, workflow.name)

if start_event is not None:
input_ev = workflow.workflow.start_event_class.model_validate(start_event)
@@ -409,7 +415,9 @@ async def _run_workflow(self, request: Request) -> JSONResponse:
ctx=context,
start_event=input_ev,
)
wrapper = self._run_workflow_handler(handler_id, workflow.name, handler)
wrapper = self._run_workflow_handler(
handler_id, workflow.name, handler, handler_metadata
)
await handler
return JSONResponse(wrapper.to_dict())
except Exception as e:
@@ -554,9 +562,12 @@ async def _run_workflow_nowait(self, request: Request) -> JSONResponse:
description: Workflow or handler identifier not found
"""
workflow = self._extract_workflow(request)
context, start_event, handler_id = await self._extract_run_params(
request, workflow.workflow, workflow.name
)
(
context,
start_event,
handler_id,
handler_metadata,
) = await self._extract_run_params(request, workflow.workflow, workflow.name)

if start_event is not None:
input_ev = workflow.workflow.start_event_class.model_validate(start_event)
@@ -571,6 +582,7 @@ async def _run_workflow_nowait(self, request: Request) -> JSONResponse:
handler_id,
workflow.name,
handler,
handler_metadata,
)
return JSONResponse(wrapper.to_dict())

@@ -748,16 +760,129 @@ async def _get_handlers(self, request: Request) -> JSONResponse:
"""
---
summary: Get handlers
description: Returns all workflow handlers.
description: Returns all workflow handlers with optional filtering.
parameters:
- in: query
name: status
schema:
type: string
enum: [running, completed, failed, cancelled]
description: Filter by handler status
- in: query

Contributor:
Might make sense to make this consistent with the agent data query DSL: https://developers.llamaindex.ai/python/cloud/llamaagents/agent-data-overview/#filter-dsl

I don't love that DSL, but there aren't a lot of great options that don't explode in complexity, and this one actually looks pretty similar. I would like to eventually move that agent data into the open source llamactl (or here?), so developing it out consistently should make it more re-usable and facilitate easier integration.

One difference is that that one just uses a POST so it doesn't have to get query parsing involved.

Contributor:
We're already parsing JSON in the query. Could just continue that pattern with the other DSL?

name: metadata_exact
schema:
type: string
description: JSON object for exact metadata field matching
- in: query
name: metadata_has_keys
schema:
type: string
description: Comma-separated list of required metadata keys
- in: query
name: metadata_contains
schema:
type: string
description: JSON object for case-insensitive substring matching on string metadata values
responses:
200:
description: List of handlers
content:
application/json:
schema:
$ref: '#/components/schemas/HandlersList'
400:
description: Invalid query parameters
"""
items = [wrapper.to_dict() for wrapper in self._handlers.values()]
# Get query parameters
query_params = request.query_params
status_filter = query_params.get("status")
metadata_exact_str = query_params.get("metadata_exact")
metadata_has_keys_str = query_params.get("metadata_has_keys")
metadata_contains_str = query_params.get("metadata_contains")

# Parse JSON parameters
metadata_exact = None
if metadata_exact_str:
try:
metadata_exact = json.loads(metadata_exact_str)
if not isinstance(metadata_exact, dict):
raise HTTPException(
detail="metadata_exact must be a JSON object",
status_code=400,
)
except json.JSONDecodeError:
raise HTTPException(
detail="Invalid JSON in metadata_exact parameter",
status_code=400,
)

metadata_contains = None
if metadata_contains_str:
try:
metadata_contains = json.loads(metadata_contains_str)
if not isinstance(metadata_contains, dict):
raise HTTPException(
detail="metadata_contains must be a JSON object",
status_code=400,
)
except json.JSONDecodeError:
raise HTTPException(
detail="Invalid JSON in metadata_contains parameter",
status_code=400,
)

metadata_has_keys = None
if metadata_has_keys_str:
metadata_has_keys = [
key.strip() for key in metadata_has_keys_str.split(",")
]

# Filter handlers
items = []
for wrapper in self._handlers.values():

Contributor:
This isn't really supporting completed workflows 😢. Generally, this querying seems like it should be pushed down into the persistence layer, at least in the long term, and we should push some of the state here in the server to an "in memory" persistence layer.

Collaborator Author:
Oh dang, yeah, once workflows complete they get removed from this list.

# Apply status filter
if status_filter and wrapper.status != status_filter:
continue

# Apply metadata_exact filter
if metadata_exact is not None:
handler_metadata = wrapper.handler_metadata or {}
matches = all(
handler_metadata.get(key) == value
for key, value in metadata_exact.items()
)
if not matches:
continue

# Apply metadata_has_keys filter
if metadata_has_keys is not None:
handler_metadata = wrapper.handler_metadata or {}
has_all_keys = all(key in handler_metadata for key in metadata_has_keys)
if not has_all_keys:
continue

# Apply metadata_contains filter
if metadata_contains is not None:
handler_metadata = wrapper.handler_metadata or {}
matches = True
for key, search_value in metadata_contains.items():
if key not in handler_metadata:
matches = False
break
handler_value = handler_metadata[key]
if not isinstance(handler_value, str) or not isinstance(
search_value, str
):
matches = False
break
if search_value.lower() not in handler_value.lower():
matches = False
break
if not matches:
continue

items.append(wrapper.to_dict())

return JSONResponse({"handlers": items})
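
Picking up the review thread above about completed handlers falling out of self._handlers, one direction is to push the filtering into the persistence layer and treat the server's in-memory state as just another store. The sketch below is purely illustrative: HandlerQuery, the query() method, and the in-memory store are not part of the current AbstractWorkflowStore interface; only update() appears in this diff, and the import path is inferred from the file location.

```python
import json
from dataclasses import dataclass, field
from typing import Any

# Assumed import path based on the file location in this PR.
from workflows.server.abstract_workflow_store import PersistentHandler


@dataclass
class HandlerQuery:
    """Hypothetical filter spec, loosely modeled on the query parameters above
    (and on the agent-data filter DSL mentioned in the review)."""

    status: str | None = None
    metadata_exact: dict[str, Any] = field(default_factory=dict)
    metadata_has_keys: list[str] = field(default_factory=list)
    metadata_contains: dict[str, str] = field(default_factory=dict)


class InMemoryWorkflowStore:
    """Sketch of an in-memory store that owns the filtering, so the HTTP handler
    only parses query params and delegates; completed handlers stay queryable
    because they live in the store rather than in the server's _handlers dict."""

    def __init__(self) -> None:
        self._handlers: dict[str, PersistentHandler] = {}

    async def update(self, handler: PersistentHandler) -> None:
        self._handlers[handler.handler_id] = handler

    async def query(self, q: HandlerQuery) -> list[PersistentHandler]:
        def matches(h: PersistentHandler) -> bool:
            meta = (
                json.loads(h.handler_metadata_json) if h.handler_metadata_json else {}
            )
            if q.status and h.status != q.status:
                return False
            if any(meta.get(k) != v for k, v in q.metadata_exact.items()):
                return False
            if any(k not in meta for k in q.metadata_has_keys):
                return False
            for k, needle in q.metadata_contains.items():
                value = meta.get(k)
                if not isinstance(value, str) or needle.lower() not in value.lower():
                    return False
            return True

        return [h for h in self._handlers.values() if matches(h)]
```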

async def _post_event(self, request: Request) -> JSONResponse:
@@ -948,13 +1073,25 @@ def _extract_workflow(self, request: Request) -> _NamedWorkflow:

async def _extract_run_params(
self, request: Request, workflow: Workflow, workflow_name: str
) -> tuple[Context | None, StartEvent | None, str]:
) -> tuple[Context | None, StartEvent | None, str, dict | None]:
try:
body = await request.json()
context_data = body.get("context")
run_kwargs = body.get("kwargs", {})
start_event_data = body.get("start_event", run_kwargs)
handler_id = body.get("handler_id")
handler_metadata = body.get("handler_metadata")

# Convert handler_metadata to dict if it's a string
if handler_metadata is not None:
if isinstance(handler_metadata, str):
handler_metadata = json.loads(handler_metadata)

if not isinstance(handler_metadata, dict):
raise HTTPException(
detail=f"Invalid handler_metadata: should be a dict, but got {type(handler_metadata)}",
status_code=400,
)

# Extract custom StartEvent if present
start_event = None
@@ -999,9 +1136,16 @@ async def _extract_run_params(
raise HTTPException(detail="Handler not found", status_code=404)

context = Context.from_dict(workflow, persisted_handlers[0].ctx)
if (
handler_metadata is None
and persisted_handlers[0].handler_metadata_json
):
handler_metadata = json.loads(
persisted_handlers[0].handler_metadata_json
)

handler_id = handler_id or nanoid()
return (context, start_event, handler_id)
return (context, start_event, handler_id, handler_metadata)

except HTTPException:
# Re-raise HTTPExceptions as-is (like start_event validation errors)
@@ -1012,7 +1156,11 @@ async def _extract_run_params(
)

def _run_workflow_handler(
self, handler_id: str, workflow_name: str, handler: WorkflowHandler
self,
handler_id: str,
workflow_name: str,
handler: WorkflowHandler,
handler_metadata: dict[str, Any] | None = None,
) -> _WorkflowHandler:
"""
Streams events from the handler, persisting them, and pushing them to a queue.
@@ -1028,9 +1176,15 @@ async def checkpoint(status: Status) -> None:
backoffs = list(self._persistence_backoff)
while True:
try:
handler_metadata_str = (
json.dumps(handler_metadata)
if handler_metadata is not None
else None
)
await self._workflow_store.update(
PersistentHandler(
handler_id=handler_id,
handler_metadata_json=handler_metadata_str,
workflow_name=workflow_name,
status=status,
ctx=ctx,
@@ -1087,6 +1241,7 @@ async def checkpoint(status: Status) -> None:
task = asyncio.create_task(_stream_events(handler))
wrapper = _WorkflowHandler(
run_handler=handler,
handler_metadata=handler_metadata,
queue=queue,
task=task,
consumer_mutex=asyncio.Lock(),
@@ -1150,10 +1305,12 @@ class _WorkflowHandler:
started_at: datetime
updated_at: datetime
completed_at: datetime | None
handler_metadata: dict[str, Any] | None

def to_dict(self) -> HandlerDict:
return HandlerDict(
handler_id=self.handler_id,
handler_metadata=self.handler_metadata,
workflow_name=self.workflow_name,
run_id=self.run_handler.run_id,
status=self.status,