Skip to content

Commit a3db8d4

Browse files
refactor: clean up multimodal extraction with single-pass processing
- Replace 3 duplicate recursive functions with single _extract_multimodal_content()
- Add MultimodalContent dataclass for clean return type
- Use Python 3.10+ match/case for idiomatic pattern matching
- Fix circular reference vulnerability with _seen set tracking
- Fix Pydantic deprecation warning (use type(obj).model_fields)
- Consolidate 3 unit tests into 1 comprehensive test
- Simplify README to be user-facing
- Remove python-dotenv dependency, use simple .env loader

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent a0b46ab commit a3db8d4

File tree

6 files changed

+198
-335
lines changed

6 files changed

+198
-335
lines changed

atomic-agents/atomic_agents/context/chat_history.py

Lines changed: 102 additions & 129 deletions
Original file line number | Diff line number | Diff line change
@@ -1,120 +1,109 @@
1+
from __future__ import annotations
2+
13
import json
24
import uuid
5+
from dataclasses import dataclass, field
36
from enum import Enum
47
from pathlib import Path
5-
from typing import Dict, List, Optional, Type
8+
from typing import TYPE_CHECKING, Any
69

710
from instructor.processing.multimodal import PDF, Image, Audio
811
from pydantic import BaseModel, Field
912

1013
from atomic_agents.base.base_io_schema import BaseIOSchema
1114

15+
if TYPE_CHECKING:
16+
from typing import Type
1217

13-
INSTRUCTOR_MULTIMODAL_TYPES = (Image, Audio, PDF)
1418

19+
MULTIMODAL_TYPES = (Image, Audio, PDF)
1520

16-
def _contains_multimodal(obj) -> bool:
17-
"""
18-
Recursively checks if an object contains any multimodal content.
1921

20-
Args:
21-
obj: The object to check.
22+
@dataclass
23+
class MultimodalContent:
24+
"""Result of extracting multimodal content from nested structures."""
2225

23-
Returns:
24-
bool: True if the object contains multimodal content, False otherwise.
25-
"""
26-
if isinstance(obj, INSTRUCTOR_MULTIMODAL_TYPES):
27-
return True
28-
elif isinstance(obj, list):
29-
return any(_contains_multimodal(item) for item in obj)
30-
elif isinstance(obj, dict):
31-
return any(_contains_multimodal(value) for value in obj.values())
32-
elif hasattr(obj, "__class__") and hasattr(obj.__class__, "model_fields"):
33-
# Pydantic model - check all fields
34-
for field_name in obj.__class__.model_fields:
35-
if hasattr(obj, field_name):
36-
if _contains_multimodal(getattr(obj, field_name)):
37-
return True
38-
return False
39-
return False
40-
41-
42-
def _extract_multimodal_objects(obj) -> List:
43-
"""
44-
Recursively extracts all multimodal objects from a nested structure.
26+
objects: list = field(default_factory=list)
27+
json_data: Any = None
4528

46-
Args:
47-
obj: The object to extract multimodal content from.
29+
@property
30+
def has_multimodal(self) -> bool:
31+
return len(self.objects) > 0
4832

49-
Returns:
50-
List: A list of all multimodal objects found.
33+
34+
def _extract_multimodal_content(obj: Any, _seen: set[int] | None = None) -> MultimodalContent:
5135
"""
52-
multimodal_objects = []
53-
54-
if isinstance(obj, INSTRUCTOR_MULTIMODAL_TYPES):
55-
multimodal_objects.append(obj)
56-
elif isinstance(obj, list):
57-
for item in obj:
58-
multimodal_objects.extend(_extract_multimodal_objects(item))
59-
elif isinstance(obj, dict):
60-
for value in obj.values():
61-
multimodal_objects.extend(_extract_multimodal_objects(value))
62-
elif hasattr(obj, "__class__") and hasattr(obj.__class__, "model_fields"):
63-
# Pydantic model - check all fields
64-
for field_name in obj.__class__.model_fields:
65-
if hasattr(obj, field_name):
66-
multimodal_objects.extend(_extract_multimodal_objects(getattr(obj, field_name)))
67-
68-
return multimodal_objects
69-
70-
71-
def _build_non_multimodal_dict(obj):
36+
Single-pass extraction of multimodal content from nested structures.
37+
38+
Returns both the multimodal objects and a JSON-serializable representation
39+
with multimodal content removed.
7240
"""
73-
Recursively builds a dictionary representation of an object,
74-
excluding multimodal content but preserving the structure.
41+
if _seen is None:
42+
_seen = set()
7543

76-
Args:
77-
obj: The object to convert.
44+
match obj:
45+
case Image() | Audio() | PDF():
46+
return MultimodalContent(objects=[obj], json_data=None)
7847

79-
Returns:
80-
The non-multimodal representation of the object, or None if the object
81-
is purely multimodal content.
82-
"""
83-
if isinstance(obj, INSTRUCTOR_MULTIMODAL_TYPES):
84-
# Multimodal content - exclude from JSON
85-
return None
86-
elif isinstance(obj, list):
87-
# Filter out multimodal items and recursively process others
88-
result = []
89-
for item in obj:
90-
processed = _build_non_multimodal_dict(item)
91-
if processed is not None:
92-
result.append(processed)
93-
return result if result else None
94-
elif isinstance(obj, dict):
95-
# Recursively process dict values, excluding multimodal
96-
result = {}
97-
for key, value in obj.items():
98-
processed = _build_non_multimodal_dict(value)
99-
if processed is not None:
100-
result[key] = processed
101-
return result if result else None
102-
elif hasattr(obj, "__class__") and hasattr(obj.__class__, "model_fields"):
103-
# Pydantic model - recursively process fields
104-
result = {}
105-
for field_name in obj.__class__.model_fields:
106-
if hasattr(obj, field_name):
107-
field_value = getattr(obj, field_name)
108-
processed = _build_non_multimodal_dict(field_value)
109-
if processed is not None:
110-
result[field_name] = processed
111-
return result if result else None
112-
else:
113-
# Primitive types or other objects - return as-is for JSON serialization
114-
# Handle Pydantic serialization for complex types
115-
if hasattr(obj, "model_dump"):
116-
return obj.model_dump()
117-
return obj
48+
case list():
49+
if id(obj) in _seen:
50+
return MultimodalContent()
51+
_seen.add(id(obj))
52+
53+
all_objects = []
54+
json_items = []
55+
for item in obj:
56+
result = _extract_multimodal_content(item, _seen)
57+
all_objects.extend(result.objects)
58+
if result.json_data is not None:
59+
json_items.append(result.json_data)
60+
61+
return MultimodalContent(
62+
objects=all_objects,
63+
json_data=json_items or None,
64+
)
65+
66+
case dict():
67+
if id(obj) in _seen:
68+
return MultimodalContent()
69+
_seen.add(id(obj))
70+
71+
all_objects = []
72+
json_dict = {}
73+
for key, value in obj.items():
74+
result = _extract_multimodal_content(value, _seen)
75+
all_objects.extend(result.objects)
76+
if result.json_data is not None:
77+
json_dict[key] = result.json_data
78+
79+
return MultimodalContent(
80+
objects=all_objects,
81+
json_data=json_dict or None,
82+
)
83+
84+
case BaseModel():
85+
if id(obj) in _seen:
86+
return MultimodalContent()
87+
_seen.add(id(obj))
88+
89+
all_objects = []
90+
json_dict = {}
91+
for field_name in type(obj).model_fields:
92+
result = _extract_multimodal_content(getattr(obj, field_name), _seen)
93+
all_objects.extend(result.objects)
94+
if result.json_data is not None:
95+
json_dict[field_name] = result.json_data
96+
97+
return MultimodalContent(
98+
objects=all_objects,
99+
json_data=json_dict or None,
100+
)
101+
102+
case _ if hasattr(obj, "model_dump"):
103+
return MultimodalContent(json_data=obj.model_dump())
104+
105+
case _:
106+
return MultimodalContent(json_data=obj)
118107

119108

120109
class Message(BaseModel):
@@ -129,30 +118,29 @@ class Message(BaseModel):
129118

130119
role: str
131120
content: BaseIOSchema
132-
turn_id: Optional[str] = None
121+
turn_id: str | None = None
133122

134123

135124
class ChatHistory:
136125
"""
137126
Manages the chat history for an AI agent.
138127
139128
Attributes:
140-
history (List[Message]): A list of messages representing the chat history.
141-
max_messages (Optional[int]): Maximum number of messages to keep in history.
142-
current_turn_id (Optional[str]): The ID of the current turn.
129+
history: A list of messages representing the chat history.
130+
max_messages: Maximum number of messages to keep in history.
131+
current_turn_id: The ID of the current turn.
143132
"""
144133

145-
def __init__(self, max_messages: Optional[int] = None):
134+
def __init__(self, max_messages: int | None = None):
146135
"""
147136
Initializes the ChatHistory with an empty history and optional constraints.
148137
149138
Args:
150-
max_messages (Optional[int]): Maximum number of messages to keep in history.
151-
When exceeded, oldest messages are removed first.
139+
max_messages: Maximum number of messages to keep. Oldest removed first.
152140
"""
153-
self.history: List[Message] = []
141+
self.history: list[Message] = []
154142
self.max_messages = max_messages
155-
self.current_turn_id: Optional[str] = None
143+
self.current_turn_id: str | None = None
156144

157145
def initialize_turn(self) -> None:
158146
"""
@@ -210,31 +198,16 @@ def get_history(self) -> List[Dict]:
210198
"""
211199
history = []
212200
for message in self.history:
213-
input_content = message.content
214-
215-
# Use recursive function to check for multimodal content at any depth
216-
has_multimodal = _contains_multimodal(input_content)
217-
218-
if has_multimodal:
219-
# For multimodal content: create mixed array with JSON + multimodal objects
220-
processed_content = []
221-
222-
# Build non-multimodal data recursively
223-
non_multimodal_data = _build_non_multimodal_dict(input_content)
224-
225-
# Add single JSON string if there are non-multimodal fields
226-
if non_multimodal_data:
227-
processed_content.append(json.dumps(non_multimodal_data, ensure_ascii=False))
228-
229-
# Extract all multimodal objects recursively and add them
230-
multimodal_objects = _extract_multimodal_objects(input_content)
231-
processed_content.extend(multimodal_objects)
232-
233-
history.append({"role": message.role, "content": processed_content})
201+
extracted = _extract_multimodal_content(message.content)
202+
203+
if extracted.has_multimodal:
204+
content = []
205+
if extracted.json_data:
206+
content.append(json.dumps(extracted.json_data, ensure_ascii=False))
207+
content.extend(extracted.objects)
208+
history.append({"role": message.role, "content": content})
234209
else:
235-
# No multimodal content: generate single cohesive JSON string
236-
content_json = input_content.model_dump_json()
237-
history.append({"role": message.role, "content": content_json})
210+
history.append({"role": message.role, "content": message.content.model_dump_json()})
238211

239212
return history
240213

0 commit comments

Comments
 (0)