Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _init_structure_tables(self):
conn = self._get_connection()

# Message structure with branch support
conn.execute("""
conn.execute(f"""
CREATE TABLE IF NOT EXISTS message_structure (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
Expand All @@ -71,13 +71,15 @@ def _init_structure_tables(self):
branch_turn_number INTEGER,
tool_name TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
FOREIGN KEY (message_id) REFERENCES agent_messages(id) ON DELETE CASCADE
FOREIGN KEY (session_id)
REFERENCES {self.sessions_table}(session_id) ON DELETE CASCADE,
FOREIGN KEY (message_id)
REFERENCES {self.messages_table}(id) ON DELETE CASCADE
)
""")

# Turn-level usage tracking with branch support and full JSON details
conn.execute("""
conn.execute(f"""
CREATE TABLE IF NOT EXISTS turn_usage (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
Expand All @@ -90,7 +92,8 @@ def _init_structure_tables(self):
input_tokens_details JSON,
output_tokens_details JSON,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
FOREIGN KEY (session_id)
REFERENCES {self.sessions_table}(session_id) ON DELETE CASCADE,
UNIQUE(session_id, branch_id, user_turn_number)
)
""")
Expand Down Expand Up @@ -160,9 +163,9 @@ def _get_all_items_sync():
with closing(conn.cursor()) as cursor:
if session_limit is None:
cursor.execute(
"""
f"""
SELECT m.message_data
FROM agent_messages m
FROM {self.messages_table} m
JOIN message_structure s ON m.id = s.message_id
WHERE m.session_id = ? AND s.branch_id = ?
ORDER BY s.sequence_number ASC
Expand All @@ -171,9 +174,9 @@ def _get_all_items_sync():
)
else:
cursor.execute(
"""
f"""
SELECT m.message_data
FROM agent_messages m
FROM {self.messages_table} m
JOIN message_structure s ON m.id = s.message_id
WHERE m.session_id = ? AND s.branch_id = ?
ORDER BY s.sequence_number DESC
Expand Down Expand Up @@ -206,9 +209,9 @@ def _get_items_sync():
# Get message IDs in correct order for this branch
if session_limit is None:
cursor.execute(
"""
f"""
SELECT m.message_data
FROM agent_messages m
FROM {self.messages_table} m
JOIN message_structure s ON m.id = s.message_id
WHERE m.session_id = ? AND s.branch_id = ?
ORDER BY s.sequence_number ASC
Expand All @@ -217,9 +220,9 @@ def _get_items_sync():
)
else:
cursor.execute(
"""
f"""
SELECT m.message_data
FROM agent_messages m
FROM {self.messages_table} m
JOIN message_structure s ON m.id = s.message_id
WHERE m.session_id = ? AND s.branch_id = ?
ORDER BY s.sequence_number DESC
Expand Down Expand Up @@ -439,7 +442,7 @@ def _add_structure_sync():
# Don't re-raise - structure metadata is supplementary

async def _cleanup_orphaned_messages(self) -> int:
"""Remove messages that exist in agent_messages but not in message_structure.
"""Remove messages that exist in the configured message table but not in message_structure.

This can happen if _add_structure_metadata fails after super().add_items() succeeds.
Used for maintaining data consistency.
Expand All @@ -453,9 +456,9 @@ def _cleanup_sync():
with closing(conn.cursor()) as cursor:
# Find messages without structure metadata
cursor.execute(
"""
f"""
SELECT am.id
FROM agent_messages am
FROM {self.messages_table} am
LEFT JOIN message_structure ms ON am.id = ms.message_id
WHERE am.session_id = ? AND ms.message_id IS NULL
""",
Expand All @@ -468,7 +471,8 @@ def _cleanup_sync():
# Delete orphaned messages
placeholders = ",".join("?" * len(orphaned_ids))
cursor.execute(
f"DELETE FROM agent_messages WHERE id IN ({placeholders})", orphaned_ids
f"DELETE FROM {self.messages_table} WHERE id IN ({placeholders})",
orphaned_ids,
)

deleted_count = cursor.rowcount
Expand Down Expand Up @@ -587,10 +591,10 @@ def _validate_turn():
conn = self._get_connection()
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
f"""
SELECT am.message_data
FROM message_structure ms
JOIN agent_messages am ON ms.message_id = am.id
JOIN {self.messages_table} am ON ms.message_id = am.id
WHERE ms.session_id = ? AND ms.branch_id = ?
AND ms.branch_turn_number = ? AND ms.message_type = 'user'
""",
Expand Down Expand Up @@ -920,13 +924,13 @@ def _get_turns_sync():
conn = self._get_connection()
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
f"""
SELECT
ms.branch_turn_number,
am.message_data,
ms.created_at
FROM message_structure ms
JOIN agent_messages am ON ms.message_id = am.id
JOIN {self.messages_table} am ON ms.message_id = am.id
WHERE ms.session_id = ? AND ms.branch_id = ?
AND ms.message_type = 'user'
ORDER BY ms.branch_turn_number
Expand Down Expand Up @@ -975,13 +979,13 @@ def _search_sync():
conn = self._get_connection()
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
f"""
SELECT
ms.branch_turn_number,
am.message_data,
ms.created_at
FROM message_structure ms
JOIN agent_messages am ON ms.message_id = am.id
JOIN {self.messages_table} am ON ms.message_id = am.id
WHERE ms.session_id = ? AND ms.branch_id = ?
AND ms.message_type = 'user'
AND am.message_data LIKE ?
Expand Down
46 changes: 46 additions & 0 deletions tests/extensions/memory/test_advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,52 @@ async def test_advanced_session_basic_functionality(agent: Agent):
session.close()


async def test_advanced_session_respects_custom_table_names():
"""AdvancedSQLiteSession should consistently use configured table names."""
session = AdvancedSQLiteSession(
session_id="advanced_custom_tables",
create_tables=True,
sessions_table="custom_agent_sessions",
messages_table="custom_agent_messages",
)

items: list[TResponseInputItem] = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "Let's do some math"},
{"role": "assistant", "content": "Sure"},
]
await session.add_items(items)

assert await session.get_items() == items

conversation_turns = await session.get_conversation_turns()
assert [turn["turn"] for turn in conversation_turns] == [1, 2]

matching_turns = await session.find_turns_by_content("math")
assert [turn["turn"] for turn in matching_turns] == [2]

conn = session._get_connection()
structure_foreign_keys = {
row[2] for row in conn.execute("PRAGMA foreign_key_list(message_structure)").fetchall()
}
usage_foreign_keys = {
row[2] for row in conn.execute("PRAGMA foreign_key_list(turn_usage)").fetchall()
}
assert structure_foreign_keys == {
session.messages_table,
session.sessions_table,
}
assert usage_foreign_keys == {session.sessions_table}

branch_name = await session.create_branch_from_turn(2, "custom_branch")
assert branch_name == "custom_branch"
assert await session.get_items() == items[:2]
assert await session.get_items(branch_id="main") == items

session.close()


async def test_message_structure_tracking(agent: Agent):
"""Test that message structure is properly tracked."""
session_id = "structure_test"
Expand Down
Loading