diff --git a/python/packages/core/agent_framework/_workflows/_edge.py b/python/packages/core/agent_framework/_workflows/_edge.py index 2d144657fe..200fdd8df2 100644 --- a/python/packages/core/agent_framework/_workflows/_edge.py +++ b/python/packages/core/agent_framework/_workflows/_edge.py @@ -2,13 +2,14 @@ import logging import uuid -from collections.abc import Callable, Sequence +from collections.abc import Callable, MutableMapping, Sequence from dataclasses import dataclass, field from typing import Any, ClassVar +from .._serialization import SerializationMixin from ._const import INTERNAL_SOURCE_ID from ._executor import Executor -from ._model_utils import DictConvertible, encode_value +from ._model_utils import encode_value logger = logging.getLogger(__name__) @@ -62,7 +63,7 @@ def _raise(*_: Any, **__: Any) -> Any: @dataclass(init=False) -class Edge(DictConvertible): +class Edge(SerializationMixin): """Model a directed, optionally-conditional hand-off between two executors. Each `Edge` captures the minimal metadata required to move a message from @@ -164,7 +165,7 @@ def should_route(self, data: Any) -> bool: return True return self._condition(data) - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: """Produce a JSON-serialisable view of the edge metadata. The representation includes the source and target executor identifiers @@ -184,7 +185,7 @@ def to_dict(self) -> dict[str, Any]: return payload @classmethod - def from_dict(cls, data: dict[str, Any]) -> "Edge": + def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "Edge": """Reconstruct an `Edge` from its serialised dictionary form. The deserialised edge will lack the executable predicate because we do @@ -259,7 +260,7 @@ def __init__(self) -> None: @dataclass(init=False) -class EdgeGroup(DictConvertible): +class EdgeGroup(SerializationMixin): """Bundle edges that share a common routing semantics under a single id. The workflow runtime manipulates `EdgeGroup` instances rather than raw @@ -342,7 +343,7 @@ def target_executor_ids(self) -> list[str]: """ return list(dict.fromkeys(edge.target_id for edge in self.edges)) - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: """Serialise the group metadata and contained edges into primitives. The payload captures each edge through its own `to_dict` call, enabling @@ -385,7 +386,7 @@ class CustomGroup(EdgeGroup): return subclass @classmethod - def from_dict(cls, data: dict[str, Any]) -> "EdgeGroup": + def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "EdgeGroup": """Hydrate the correct `EdgeGroup` subclass from serialised state. The method inspects the `type` field, allocates the corresponding class @@ -556,7 +557,7 @@ def selection_func(self) -> Callable[[Any, list[str]], list[str]] | None: """ return self._selection_func - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: """Serialise the fan-out group while preserving selection metadata. In addition to the base `EdgeGroup` payload we embed the human-friendly @@ -569,7 +570,7 @@ def to_dict(self) -> dict[str, Any]: snapshot = group.to_dict() assert snapshot["selection_func_name"] == "" """ - payload = super().to_dict() + payload = super().to_dict(**kwargs) payload["selection_func_name"] = self.selection_func_name return payload @@ -610,7 +611,7 @@ def __init__(self, source_ids: Sequence[str], target_id: str, *, id: str | None @dataclass(init=False) -class SwitchCaseEdgeGroupCase(DictConvertible): +class SwitchCaseEdgeGroupCase(SerializationMixin): """Persistable description of a single conditional branch in a switch-case. Unlike the runtime `Case` object this serialisable variant stores only the @@ -684,7 +685,7 @@ def condition(self) -> Callable[[Any], bool]: """ return self._condition - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: """Serialise the case metadata without the executable predicate. Examples: @@ -699,7 +700,7 @@ def to_dict(self) -> dict[str, Any]: return payload @classmethod - def from_dict(cls, data: dict[str, Any]) -> "SwitchCaseEdgeGroupCase": + def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "SwitchCaseEdgeGroupCase": """Instantiate a case from its serialised dictionary payload. Examples: @@ -717,7 +718,7 @@ def from_dict(cls, data: dict[str, Any]) -> "SwitchCaseEdgeGroupCase": @dataclass(init=False) -class SwitchCaseEdgeGroupDefault(DictConvertible): +class SwitchCaseEdgeGroupDefault(SerializationMixin): """Persistable descriptor for the fallback branch of a switch-case group. The default branch is guaranteed to exist and is invoked when every other @@ -741,7 +742,7 @@ def __init__(self, target_id: str) -> None: self.target_id = target_id self.type = "Default" - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: """Serialise the default branch metadata for persistence or logging. Examples: @@ -753,7 +754,9 @@ def to_dict(self) -> dict[str, Any]: return {"target_id": self.target_id, "type": self.type} @classmethod - def from_dict(cls, data: dict[str, Any]) -> "SwitchCaseEdgeGroupDefault": + def from_dict( + cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any + ) -> "SwitchCaseEdgeGroupDefault": """Recreate the default branch from its persisted form. Examples: @@ -844,7 +847,7 @@ def selection_func(message: Any, targets: list[str]) -> list[str]: self.selection_func_name = None # type: ignore[attr-defined] self.cases = list(cases) - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: """Serialise the switch-case group, capturing all case descriptors. Each case is converted using `encode_value` to respect dataclass @@ -863,7 +866,7 @@ def to_dict(self) -> dict[str, Any]: snapshot = group.to_dict() assert len(snapshot["cases"]) == 2 """ - payload = super().to_dict() + payload = super().to_dict(**kwargs) payload["cases"] = [encode_value(case) for case in self.cases] return payload diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 1563dd7c53..77fd493bc4 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable from typing import Any, TypeVar +from .._serialization import SerializationMixin from ..observability import create_processing_span from ._events import ( ExecutorCompletedEvent, @@ -15,7 +16,6 @@ WorkflowErrorDetails, _framework_event_origin, # type: ignore[reportPrivateUsage] ) -from ._model_utils import DictConvertible from ._request_info_mixin import RequestInfoMixin from ._runner_context import Message, MessageType, RunnerContext from ._shared_state import SharedState @@ -26,7 +26,7 @@ # region Executor -class Executor(RequestInfoMixin, DictConvertible): +class Executor(RequestInfoMixin, SerializationMixin): """Base class for all workflow executors that process messages and perform computations. ## Overview @@ -422,7 +422,7 @@ def workflow_output_types(self) -> list[type[Any]]: return list(output_types) - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: """Serialize executor definition for workflow topology export.""" return {"id": self.id, "type": self.type} diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 9d21391ad8..84f8874288 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -7,7 +7,7 @@ import re import sys from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Sequence +from collections.abc import AsyncIterable, MutableMapping, Sequence from dataclasses import dataclass, field from enum import Enum from typing import Any, Protocol, TypeVar, Union, cast @@ -23,6 +23,7 @@ FunctionResultContent, Role, ) +from agent_framework._serialization import SerializationMixin from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator from ._checkpoint import CheckpointStorage, WorkflowCheckpoint @@ -38,7 +39,7 @@ group_chat_orchestrator, ) from ._message_utils import normalize_messages_input -from ._model_utils import DictConvertible, encode_value +from ._model_utils import encode_value from ._participant_utils import GroupChatParticipantSpec, participant_description from ._request_info_mixin import response_handler from ._workflow import Workflow, WorkflowRunResult @@ -328,7 +329,7 @@ def _new_chat_message_list() -> list[ChatMessage]: @dataclass -class _MagenticStartMessage(DictConvertible): +class _MagenticStartMessage(SerializationMixin): """Internal: A message to start a magentic workflow.""" messages: list[ChatMessage] = field(default_factory=_new_chat_message_list) @@ -356,7 +357,7 @@ def from_string(cls, task_text: str) -> "_MagenticStartMessage": """Create a MagenticStartMessage from a simple string.""" return cls(task_text) - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: """Create a dict representation of the message.""" return { "messages": [message.to_dict() for message in self.messages], @@ -364,7 +365,7 @@ def to_dict(self) -> dict[str, Any]: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "_MagenticStartMessage": + def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "_MagenticStartMessage": """Create from a dict.""" if "messages" in data: raw_messages = data["messages"] @@ -446,17 +447,17 @@ class _MagenticPlanReviewReply: @dataclass -class _MagenticTaskLedger(DictConvertible): +class _MagenticTaskLedger(SerializationMixin): """Internal: Task ledger for the Standard Magentic manager.""" facts: ChatMessage plan: ChatMessage - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: return {"facts": _message_to_payload(self.facts), "plan": _message_to_payload(self.plan)} @classmethod - def from_dict(cls, data: dict[str, Any]) -> "_MagenticTaskLedger": + def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "_MagenticTaskLedger": return cls( facts=_message_from_payload(data.get("facts")), plan=_message_from_payload(data.get("plan")), @@ -464,17 +465,19 @@ def from_dict(cls, data: dict[str, Any]) -> "_MagenticTaskLedger": @dataclass -class _MagenticProgressLedgerItem(DictConvertible): +class _MagenticProgressLedgerItem(SerializationMixin): """Internal: A progress ledger item.""" reason: str answer: str | bool - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: return {"reason": self.reason, "answer": self.answer} @classmethod - def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedgerItem": + def from_dict( + cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any + ) -> "_MagenticProgressLedgerItem": answer_value = data.get("answer") if not isinstance(answer_value, (str, bool)): answer_value = "" # Default to empty string if not str or bool @@ -482,7 +485,7 @@ def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedgerItem": @dataclass -class _MagenticProgressLedger(DictConvertible): +class _MagenticProgressLedger(SerializationMixin): """Internal: A progress ledger for tracking workflow progress.""" is_request_satisfied: _MagenticProgressLedgerItem @@ -491,7 +494,7 @@ class _MagenticProgressLedger(DictConvertible): next_speaker: _MagenticProgressLedgerItem instruction_or_question: _MagenticProgressLedgerItem - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: return { "is_request_satisfied": self.is_request_satisfied.to_dict(), "is_in_loop": self.is_in_loop.to_dict(), @@ -501,7 +504,7 @@ def to_dict(self) -> dict[str, Any]: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedger": + def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "_MagenticProgressLedger": return cls( is_request_satisfied=_MagenticProgressLedgerItem.from_dict(data.get("is_request_satisfied", {})), is_in_loop=_MagenticProgressLedgerItem.from_dict(data.get("is_in_loop", {})), @@ -512,7 +515,7 @@ def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedger": @dataclass -class MagenticContext(DictConvertible): +class MagenticContext(SerializationMixin): """Context for the Magentic manager.""" task: ChatMessage @@ -522,7 +525,7 @@ class MagenticContext(DictConvertible): stall_count: int = 0 reset_count: int = 0 - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: return { "task": _message_to_payload(self.task), "chat_history": [_message_to_payload(msg) for msg in self.chat_history], @@ -533,7 +536,7 @@ def to_dict(self) -> dict[str, Any]: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "MagenticContext": + def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "MagenticContext": chat_history_payload = data.get("chat_history", []) history: list[ChatMessage] = [] for item in chat_history_payload: @@ -557,6 +560,12 @@ def reset(self) -> None: self.stall_count = 0 self.reset_count += 1 + def clone(self, *, deep: bool = True) -> Self: + """Create a copy of this context.""" + import copy + + return copy.deepcopy(self) if deep else copy.copy(self) # type: ignore[return-value] + # endregion Messages and Types @@ -2418,7 +2427,8 @@ async def _validate_checkpoint_participants( if not isinstance(orchestrator_state, dict): return - context_payload = orchestrator_state.get("magentic_context") + orchestrator_state_dict = cast(dict[str, Any], orchestrator_state) + context_payload = orchestrator_state_dict.get("magentic_context") if not isinstance(context_payload, dict): return diff --git a/python/packages/core/agent_framework/_workflows/_model_utils.py b/python/packages/core/agent_framework/_workflows/_model_utils.py index 72380901c6..372cdfc5ed 100644 --- a/python/packages/core/agent_framework/_workflows/_model_utils.py +++ b/python/packages/core/agent_framework/_workflows/_model_utils.py @@ -1,48 +1,13 @@ # Copyright (c) Microsoft. All rights reserved. -import copy -import sys -from typing import Any, TypeVar, cast +from typing import Any -if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover -else: - from typing_extensions import Self # pragma: no cover - -TModel = TypeVar("TModel", bound="DictConvertible") - - -class DictConvertible: - """Mixin providing conversion helpers for plain Python models.""" - - def to_dict(self) -> dict[str, Any]: - raise NotImplementedError - - @classmethod - def from_dict(cls: type[TModel], data: dict[str, Any]) -> TModel: - return cls(**data) # type: ignore[arg-type] - - def clone(self, *, deep: bool = True) -> Self: - return copy.deepcopy(self) if deep else copy.copy(self) # type: ignore[return-value] - - def to_json(self) -> str: - import json - - return json.dumps(self.to_dict()) - - @classmethod - def from_json(cls: type[TModel], raw: str) -> TModel: - import json - - data = json.loads(raw) - if not isinstance(data, dict): - raise ValueError("JSON payload must decode to a mapping") - return cls.from_dict(cast(dict[str, Any], data)) +from .._serialization import SerializationProtocol def encode_value(value: Any) -> Any: """Recursively encode values for JSON-friendly serialization.""" - if isinstance(value, DictConvertible): + if isinstance(value, SerializationProtocol): return value.to_dict() if isinstance(value, dict): return {k: encode_value(v) for k, v in value.items()} # type: ignore[misc] diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index e5fd02a611..c6b398166a 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -10,6 +10,7 @@ from collections.abc import AsyncIterable, Awaitable, Callable from typing import Any +from .._serialization import SerializationMixin from ..observability import OtelAttr, capture_exception, create_workflow_span from ._agent import WorkflowAgent from ._checkpoint import CheckpointStorage @@ -30,7 +31,6 @@ _framework_event_origin, # type: ignore ) from ._executor import Executor -from ._model_utils import DictConvertible from ._runner import Runner from ._runner_context import RunnerContext from ._shared_state import SharedState @@ -105,7 +105,7 @@ def status_timeline(self) -> list[WorkflowStatusEvent]: # region Workflow -class Workflow(DictConvertible): +class Workflow(SerializationMixin): """A graph-based execution engine that orchestrates connected executors. ## Overview @@ -237,7 +237,7 @@ def _reset_running_flag(self) -> None: """Reset the running flag.""" self._is_running = False - def to_dict(self) -> dict[str, Any]: + def to_dict(self, **kwargs: Any) -> dict[str, Any]: """Serialize the workflow definition into a JSON-ready dictionary.""" data: dict[str, Any] = { "id": self.id, @@ -269,10 +269,6 @@ def to_dict(self) -> dict[str, Any]: return data - def to_json(self) -> str: - """Serialize the workflow definition to JSON.""" - return json.dumps(self.to_dict()) - def get_start_executor(self) -> Executor: """Get the starting executor of the workflow.