1111""" 
1212
1313import  threading 
14- from  typing  import  Optional 
14+ from  typing  import  TYPE_CHECKING 
1515
1616from  ...hooks .registry  import  HookProvider , HookRegistry 
17- from  ...multiagent .base  import  MultiAgentBase 
1817from  ...session  import  SessionManager 
1918from  .multiagent_events  import  (
20-     AfterGraphInvocationEvent ,
19+     AfterMultiAgentInvocationEvent ,
2120    AfterNodeInvocationEvent ,
22-     BeforeGraphInvocationEvent ,
21+     BeforeMultiAgentInvocationEvent ,
2322    BeforeNodeInvocationEvent ,
2423    MultiAgentInitializationEvent ,
25-     MultiAgentState ,
2624)
27- from  .multiagent_state_adapter  import  MultiAgentAdapter 
2825
26+ if  TYPE_CHECKING :
27+     from  ...multiagent .base  import  MultiAgentBase 
2928
30- def  _get_multiagent_state (
31-     multiagent_state : Optional [MultiAgentState ],
32-     orchestrator : MultiAgentBase ,
33- ) ->  MultiAgentState :
34-     if  multiagent_state  is  not None :
35-         return  multiagent_state 
3629
37-     return  MultiAgentAdapter .create_multi_agent_state (orchestrator = orchestrator )
38- 
39- 
40- class  MultiAgentHook (HookProvider ):
30+ class  PersistentHook (HookProvider ):
4131    """Hook provider for automatic multi-agent session persistence. 
4232
4333    This hook automatically persists multi-agent orchestrator state at key 
4434    execution points to enable resumable execution after interruptions. 
4535
46-     Args: 
47-         session_manager: SessionManager instance for state persistence 
48-         session_id: Unique identifier for the session 
4936    """ 
5037
51-     def  __init__ (self , session_manager : SessionManager ,  session_id :  str ):
38+     def  __init__ (self , session_manager : SessionManager ):
5239        """Initialize the multi-agent persistence hook. 
5340
5441        Args: 
5542            session_manager: SessionManager instance for state persistence 
56-             session_id: Unique identifier for the session 
5743        """ 
5844        self ._session_manager  =  session_manager 
59-         self ._session_id  =  session_id 
6045        self ._lock  =  threading .RLock ()
6146
6247    def  register_hooks (self , registry : HookRegistry , ** kwargs : object ) ->  None :
@@ -67,40 +52,40 @@ def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None:
6752            **kwargs: Additional keyword arguments (unused) 
6853        """ 
6954        registry .add_callback (MultiAgentInitializationEvent , self ._on_initialization )
70-         registry .add_callback (BeforeGraphInvocationEvent , self ._on_before_graph )
55+         registry .add_callback (BeforeMultiAgentInvocationEvent , self ._on_before_multiagent )
7156        registry .add_callback (BeforeNodeInvocationEvent , self ._on_before_node )
7257        registry .add_callback (AfterNodeInvocationEvent , self ._on_after_node )
73-         registry .add_callback (AfterGraphInvocationEvent , self ._on_after_graph )
58+         registry .add_callback (AfterMultiAgentInvocationEvent , self ._on_after_multiagent )
7459
75-     def  _on_initialization (self , event : MultiAgentInitializationEvent ):
60+     # TODO: We can add **kwarg or invocation_state later if we need to persist 
61+     def  _on_initialization (self , event : MultiAgentInitializationEvent ) ->  None :
7662        """Persist state when multi-agent orchestrator initializes.""" 
77-         self ._persist (_get_multiagent_state ( event .state ,  event . orchestrator ) )
63+         self ._persist (event .orchestrator )
7864
79-     def  _on_before_graph (self , event : BeforeGraphInvocationEvent ) :
80-         """Hook called before graph execution starts .""" 
65+     def  _on_before_multiagent (self , event : BeforeMultiAgentInvocationEvent )  ->   None :
66+         """Persist state when multi-agent orchestrator initializes .""" 
8167        pass 
8268
83-     def  _on_before_node (self , event : BeforeNodeInvocationEvent ):
69+     def  _on_before_node (self , event : BeforeNodeInvocationEvent )  ->   None :
8470        """Hook called before individual node execution.""" 
8571        pass 
8672
87-     def  _on_after_node (self , event : AfterNodeInvocationEvent ):
73+     def  _on_after_node (self , event : AfterNodeInvocationEvent )  ->   None :
8874        """Persist state after each node completes execution.""" 
89-         multi_agent_state  =  _get_multiagent_state (multiagent_state = event .state , orchestrator = event .orchestrator )
90-         self ._persist (multi_agent_state )
75+         self ._persist (event .orchestrator )
9176
92-     def  _on_after_graph (self , event : AfterGraphInvocationEvent ) :
77+     def  _on_after_multiagent (self , event : AfterMultiAgentInvocationEvent )  ->   None :
9378        """Persist final state after graph execution completes.""" 
94-         multiagent_state  =  _get_multiagent_state (multiagent_state = event .state , orchestrator = event .orchestrator )
95-         self ._persist (multiagent_state )
79+         self ._persist (event .orchestrator )
9680
97-     def  _persist (self , multiagent_state :  MultiAgentState ) ->  None :
81+     def  _persist (self , orchestrator :  "MultiAgentBase" ) ->  None :
9882        """Persist the provided MultiAgentState using the configured SessionManager. 
9983
10084        This method is synchronized across threads/tasks to avoid write races. 
10185
10286        Args: 
103-             multiagent_state : State to persist 
87+             orchestrator : State to persist 
10488        """ 
89+         current_state  =  orchestrator .get_state_from_orchestrator ()
10590        with  self ._lock :
106-             self ._session_manager .write_multi_agent_state ( multiagent_state )
91+             self ._session_manager .write_multi_agent_json ( current_state )
0 commit comments