1+ from __future__ import annotations
2+
13import json
24import uuid
5+ from dataclasses import dataclass , field
36from enum import Enum
47from pathlib import Path
5- from typing import Dict , List , Optional , Type
8+ from typing import TYPE_CHECKING , Any
69
710from instructor .processing .multimodal import PDF , Image , Audio
811from pydantic import BaseModel , Field
912
1013from atomic_agents .base .base_io_schema import BaseIOSchema
1114
15+ if TYPE_CHECKING :
16+ from typing import Type
1217
13- INSTRUCTOR_MULTIMODAL_TYPES = (Image , Audio , PDF )
1418
19+ MULTIMODAL_TYPES = (Image , Audio , PDF )
1520
16- def _contains_multimodal (obj ) -> bool :
17- """
18- Recursively checks if an object contains any multimodal content.
1921
20- Args:
21- obj: The object to check.
22+ @dataclass
23+ class MultimodalContent :
24+ """Result of extracting multimodal content from nested structures."""
2225
23- Returns:
24- bool: True if the object contains multimodal content, False otherwise.
25- """
26- if isinstance (obj , INSTRUCTOR_MULTIMODAL_TYPES ):
27- return True
28- elif isinstance (obj , list ):
29- return any (_contains_multimodal (item ) for item in obj )
30- elif isinstance (obj , dict ):
31- return any (_contains_multimodal (value ) for value in obj .values ())
32- elif hasattr (obj , "__class__" ) and hasattr (obj .__class__ , "model_fields" ):
33- # Pydantic model - check all fields
34- for field_name in obj .__class__ .model_fields :
35- if hasattr (obj , field_name ):
36- if _contains_multimodal (getattr (obj , field_name )):
37- return True
38- return False
39- return False
40-
41-
42- def _extract_multimodal_objects (obj ) -> List :
43- """
44- Recursively extracts all multimodal objects from a nested structure.
26+ objects : list = field (default_factory = list )
27+ json_data : Any = None
4528
46- Args:
47- obj: The object to extract multimodal content from.
29+ @property
30+ def has_multimodal (self ) -> bool :
31+ return len (self .objects ) > 0
4832
49- Returns:
50- List: A list of all multimodal objects found.
33+
34+ def _extract_multimodal_content ( obj : Any , _seen : set [ int ] | None = None ) -> MultimodalContent :
5135 """
52- multimodal_objects = []
53-
54- if isinstance (obj , INSTRUCTOR_MULTIMODAL_TYPES ):
55- multimodal_objects .append (obj )
56- elif isinstance (obj , list ):
57- for item in obj :
58- multimodal_objects .extend (_extract_multimodal_objects (item ))
59- elif isinstance (obj , dict ):
60- for value in obj .values ():
61- multimodal_objects .extend (_extract_multimodal_objects (value ))
62- elif hasattr (obj , "__class__" ) and hasattr (obj .__class__ , "model_fields" ):
63- # Pydantic model - check all fields
64- for field_name in obj .__class__ .model_fields :
65- if hasattr (obj , field_name ):
66- multimodal_objects .extend (_extract_multimodal_objects (getattr (obj , field_name )))
67-
68- return multimodal_objects
69-
70-
71- def _build_non_multimodal_dict (obj ):
36+ Single-pass extraction of multimodal content from nested structures.
37+
38+ Returns both the multimodal objects and a JSON-serializable representation
39+ with multimodal content removed.
7240 """
73- Recursively builds a dictionary representation of an object,
74- excluding multimodal content but preserving the structure.
41+ if _seen is None :
42+ _seen = set ()
7543
76- Args:
77- obj: The object to convert.
44+ match obj :
45+ case Image () | Audio () | PDF ():
46+ return MultimodalContent (objects = [obj ], json_data = None )
7847
79- Returns:
80- The non-multimodal representation of the object, or None if the object
81- is purely multimodal content.
82- """
83- if isinstance (obj , INSTRUCTOR_MULTIMODAL_TYPES ):
84- # Multimodal content - exclude from JSON
85- return None
86- elif isinstance (obj , list ):
87- # Filter out multimodal items and recursively process others
88- result = []
89- for item in obj :
90- processed = _build_non_multimodal_dict (item )
91- if processed is not None :
92- result .append (processed )
93- return result if result else None
94- elif isinstance (obj , dict ):
95- # Recursively process dict values, excluding multimodal
96- result = {}
97- for key , value in obj .items ():
98- processed = _build_non_multimodal_dict (value )
99- if processed is not None :
100- result [key ] = processed
101- return result if result else None
102- elif hasattr (obj , "__class__" ) and hasattr (obj .__class__ , "model_fields" ):
103- # Pydantic model - recursively process fields
104- result = {}
105- for field_name in obj .__class__ .model_fields :
106- if hasattr (obj , field_name ):
107- field_value = getattr (obj , field_name )
108- processed = _build_non_multimodal_dict (field_value )
109- if processed is not None :
110- result [field_name ] = processed
111- return result if result else None
112- else :
113- # Primitive types or other objects - return as-is for JSON serialization
114- # Handle Pydantic serialization for complex types
115- if hasattr (obj , "model_dump" ):
116- return obj .model_dump ()
117- return obj
48+ case list ():
49+ if id (obj ) in _seen :
50+ return MultimodalContent ()
51+ _seen .add (id (obj ))
52+
53+ all_objects = []
54+ json_items = []
55+ for item in obj :
56+ result = _extract_multimodal_content (item , _seen )
57+ all_objects .extend (result .objects )
58+ if result .json_data is not None :
59+ json_items .append (result .json_data )
60+
61+ return MultimodalContent (
62+ objects = all_objects ,
63+ json_data = json_items or None ,
64+ )
65+
66+ case dict ():
67+ if id (obj ) in _seen :
68+ return MultimodalContent ()
69+ _seen .add (id (obj ))
70+
71+ all_objects = []
72+ json_dict = {}
73+ for key , value in obj .items ():
74+ result = _extract_multimodal_content (value , _seen )
75+ all_objects .extend (result .objects )
76+ if result .json_data is not None :
77+ json_dict [key ] = result .json_data
78+
79+ return MultimodalContent (
80+ objects = all_objects ,
81+ json_data = json_dict or None ,
82+ )
83+
84+ case BaseModel ():
85+ if id (obj ) in _seen :
86+ return MultimodalContent ()
87+ _seen .add (id (obj ))
88+
89+ all_objects = []
90+ json_dict = {}
91+ for field_name in type (obj ).model_fields :
92+ result = _extract_multimodal_content (getattr (obj , field_name ), _seen )
93+ all_objects .extend (result .objects )
94+ if result .json_data is not None :
95+ json_dict [field_name ] = result .json_data
96+
97+ return MultimodalContent (
98+ objects = all_objects ,
99+ json_data = json_dict or None ,
100+ )
101+
102+ case _ if hasattr (obj , "model_dump" ):
103+ return MultimodalContent (json_data = obj .model_dump ())
104+
105+ case _:
106+ return MultimodalContent (json_data = obj )
118107
119108
120109class Message (BaseModel ):
@@ -129,30 +118,29 @@ class Message(BaseModel):
129118
130119 role : str
131120 content : BaseIOSchema
132- turn_id : Optional [ str ] = None
121+ turn_id : str | None = None
133122
134123
135124class ChatHistory :
136125 """
137126 Manages the chat history for an AI agent.
138127
139128 Attributes:
140- history (List[Message]) : A list of messages representing the chat history.
141- max_messages (Optional[int]) : Maximum number of messages to keep in history.
142- current_turn_id (Optional[str]) : The ID of the current turn.
129+ history: A list of messages representing the chat history.
130+ max_messages: Maximum number of messages to keep in history.
131+ current_turn_id: The ID of the current turn.
143132 """
144133
145- def __init__ (self , max_messages : Optional [ int ] = None ):
134+ def __init__ (self , max_messages : int | None = None ):
146135 """
147136 Initializes the ChatHistory with an empty history and optional constraints.
148137
149138 Args:
150- max_messages (Optional[int]): Maximum number of messages to keep in history.
151- When exceeded, oldest messages are removed first.
139+ max_messages: Maximum number of messages to keep. Oldest removed first.
152140 """
153- self .history : List [Message ] = []
141+ self .history : list [Message ] = []
154142 self .max_messages = max_messages
155- self .current_turn_id : Optional [ str ] = None
143+ self .current_turn_id : str | None = None
156144
157145 def initialize_turn (self ) -> None :
158146 """
@@ -210,31 +198,16 @@ def get_history(self) -> List[Dict]:
210198 """
211199 history = []
212200 for message in self .history :
213- input_content = message .content
214-
215- # Use recursive function to check for multimodal content at any depth
216- has_multimodal = _contains_multimodal (input_content )
217-
218- if has_multimodal :
219- # For multimodal content: create mixed array with JSON + multimodal objects
220- processed_content = []
221-
222- # Build non-multimodal data recursively
223- non_multimodal_data = _build_non_multimodal_dict (input_content )
224-
225- # Add single JSON string if there are non-multimodal fields
226- if non_multimodal_data :
227- processed_content .append (json .dumps (non_multimodal_data , ensure_ascii = False ))
228-
229- # Extract all multimodal objects recursively and add them
230- multimodal_objects = _extract_multimodal_objects (input_content )
231- processed_content .extend (multimodal_objects )
232-
233- history .append ({"role" : message .role , "content" : processed_content })
201+ extracted = _extract_multimodal_content (message .content )
202+
203+ if extracted .has_multimodal :
204+ content = []
205+ if extracted .json_data :
206+ content .append (json .dumps (extracted .json_data , ensure_ascii = False ))
207+ content .extend (extracted .objects )
208+ history .append ({"role" : message .role , "content" : content })
234209 else :
235- # No multimodal content: generate single cohesive JSON string
236- content_json = input_content .model_dump_json ()
237- history .append ({"role" : message .role , "content" : content_json })
210+ history .append ({"role" : message .role , "content" : message .content .model_dump_json ()})
238211
239212 return history
240213
0 commit comments