diff --git a/camel/agents/_utils.py b/camel/agents/_utils.py index edae576813..415ab0a0ee 100644 --- a/camel/agents/_utils.py +++ b/camel/agents/_utils.py @@ -25,6 +25,44 @@ logger = logging.getLogger(__name__) +def build_default_summary_prompt(conversation_text: str) -> str: + r"""Create the default prompt used for conversation summarization. + + Args: + conversation_text (str): The conversation to be summarized. + + Returns: + str: A formatted prompt instructing the model to produce a structured + markdown summary. + """ + template = textwrap.dedent( + """\ + Summarize the conversation below. + Produce markdown that strictly follows this outline and numbering: + + Summary: + 1. **Primary Request and Intent**: + 2. **Key Concepts**: + 3. **Errors and Fixes**: + 4. **Problem Solving**: + 5. **Pending Tasks**: + 6. **Current Work**: + 7. **Optional Next Step**: + + Requirements: + - Use bullet lists under each section (`- item`). If a section has no + information, output `- None noted`. + - Keep the ordering, headings, and formatting as written above. + - Focus on concrete actions, findings, and decisions. + - Do not invent details that are not supported by the conversation. + + Conversation: + {conversation_text} + """ + ) + return template.format(conversation_text=conversation_text) + + def generate_tool_prompt(tool_schema_list: List[Dict[str, Any]]) -> str: r"""Generates a tool prompt based on the provided tool schema list. diff --git a/camel/agents/chat_agent.py b/camel/agents/chat_agent.py index 037d19aa45..132979ada4 100644 --- a/camel/agents/chat_agent.py +++ b/camel/agents/chat_agent.py @@ -20,7 +20,6 @@ import hashlib import inspect import json -import math import os import random import re @@ -57,6 +56,7 @@ from camel.agents._types import ModelResponse, ToolCallRequest from camel.agents._utils import ( + build_default_summary_prompt, convert_to_function_tool, convert_to_schema, get_info_dict, @@ -68,6 +68,7 @@ from camel.memories import ( AgentMemory, ChatHistoryMemory, + ContextRecord, MemoryRecord, ScoreBasedContextCreator, ) @@ -102,6 +103,16 @@ from camel.utils.commons import dependencies_required from camel.utils.context_utils import ContextUtility +TOKEN_LIMIT_ERROR_MARKERS = ( + "context_length_exceeded", + "prompt is too long", + "exceeded your current quota", + "tokens must be reduced", + "context length", + "token count", + "context limit", +) + if TYPE_CHECKING: from camel.terminators import ResponseTerminator @@ -354,9 +365,9 @@ class ChatAgent(BaseAgent): message_window_size (int, optional): The maximum number of previous messages to include in the context window. If `None`, no windowing is performed. (default: :obj:`None`) - token_limit (int, optional): The maximum number of tokens in a context. - The context will be automatically pruned to fulfill the limitation. - If `None`, it will be set according to the backend model. + summarize_threshold (int, optional): The percentage of the context + window that triggers summarization. If `None`, will trigger + summarization when the context window is full. (default: :obj:`None`) output_language (str, optional): The language to be output by the agent. (default: :obj:`None`) @@ -414,6 +425,10 @@ class ChatAgent(BaseAgent): updates return accumulated content (current behavior). When False, partial updates return only the incremental delta. (default: :obj:`True`) + summary_window_ratio (float, optional): Maximum fraction of the total + context window that can be occupied by summary information. 
Used + to limit how much of the model's context is reserved for + summarization results. (default: :obj:`0.6`) """ def __init__( @@ -436,6 +451,7 @@ def __init__( ] = None, memory: Optional[AgentMemory] = None, message_window_size: Optional[int] = None, + summarize_threshold: Optional[int] = 50, token_limit: Optional[int] = None, output_language: Optional[str] = None, tools: Optional[List[Union[FunctionTool, Callable]]] = None, @@ -458,6 +474,7 @@ def __init__( retry_delay: float = 1.0, step_timeout: Optional[float] = None, stream_accumulate: bool = True, + summary_window_ratio: float = 0.6, ) -> None: if isinstance(model, ModelManager): self.model_backend = model @@ -476,7 +493,7 @@ def __init__( # Set up memory context_creator = ScoreBasedContextCreator( self.model_backend.token_counter, - token_limit or self.model_backend.token_limit, + self.model_backend.token_limit, ) self._memory: AgentMemory = memory or ChatHistoryMemory( @@ -501,6 +518,21 @@ def __init__( ) self.init_messages() + # Set up summarize threshold with validation + if summarize_threshold is not None: + if not (0 < summarize_threshold <= 100): + raise ValueError( + f"summarize_threshold must be between 0 and 100, " + f"got {summarize_threshold}" + ) + logger.info( + f"Automatic context compression is enabled. Will trigger " + f"summarization when context window exceeds " + f"{summarize_threshold}% of the total token limit." + ) + self.summarize_threshold = summarize_threshold + self._reset_summary_state() + # Set up role name and role type self.role_name: str = ( getattr(self.system_message, "role_name", None) or "assistant" @@ -548,11 +580,16 @@ def __init__( self._context_utility: Optional[ContextUtility] = None self._context_summary_agent: Optional["ChatAgent"] = None self.stream_accumulate = stream_accumulate + self._last_tool_call_record: Optional[ToolCallingRecord] = None + self._last_tool_call_signature: Optional[str] = None + self._last_token_limit_tool_signature: Optional[str] = None + self.summary_window_ratio = summary_window_ratio def reset(self): r"""Resets the :obj:`ChatAgent` to its initial state.""" self.terminated = False self.init_messages() + self._reset_summary_state() for terminator in self.response_terminators: terminator.reset() @@ -759,6 +796,329 @@ def _get_full_tool_schemas(self) -> List[Dict[str, Any]]: for func_tool in self._internal_tools.values() ] + @staticmethod + def _is_token_limit_error(error: Exception) -> bool: + r"""Return True when the exception message indicates a token limit.""" + error_message = str(error).lower() + return any( + marker in error_message for marker in TOKEN_LIMIT_ERROR_MARKERS + ) + + @staticmethod + def _is_tool_related_record(record: MemoryRecord) -> bool: + r"""Determine whether the given memory record + belongs to a tool call.""" + if record.role_at_backend in { + OpenAIBackendRole.TOOL, + OpenAIBackendRole.FUNCTION, + }: + return True + + if ( + record.role_at_backend == OpenAIBackendRole.ASSISTANT + and isinstance(record.message, FunctionCallingMessage) + ): + return True + + return False + + def _find_indices_to_remove_for_last_tool_pair( + self, recent_records: List[ContextRecord] + ) -> List[int]: + """Find indices of records that should be removed to clean up the most + recent incomplete tool interaction pair. + + This method identifies tool call/result pairs by tool_call_id and + returns the exact indices to remove, allowing non-contiguous deletions. 
+ + Logic: + - If the last record is a tool result (TOOL/FUNCTION) with a + tool_call_id, find the matching assistant call anywhere in history + and return both indices. + - If the last record is an assistant tool call without a result yet, + return just that index. + - For normal messages (non tool-related): remove just the last one. + - Fallback: If no tool_call_id is available, use heuristic (last 2 if + tool-related, otherwise last 1). + + Returns: + List[int]: Indices to remove (may be non-contiguous). + """ + if not recent_records: + return [] + + last_idx = len(recent_records) - 1 + last_record = recent_records[last_idx].memory_record + + # Case A: Last is an ASSISTANT tool call with no result yet + if ( + last_record.role_at_backend == OpenAIBackendRole.ASSISTANT + and isinstance(last_record.message, FunctionCallingMessage) + and last_record.message.result is None + ): + return [last_idx] + + # Case B: Last is TOOL/FUNCTION result, try id-based pairing + if last_record.role_at_backend in { + OpenAIBackendRole.TOOL, + OpenAIBackendRole.FUNCTION, + }: + tool_id = None + if isinstance(last_record.message, FunctionCallingMessage): + tool_id = last_record.message.tool_call_id + + if tool_id: + for idx in range(len(recent_records) - 2, -1, -1): + rec = recent_records[idx].memory_record + if rec.role_at_backend != OpenAIBackendRole.ASSISTANT: + continue + + # Check if this assistant message contains the tool_call_id + matched = False + + # Case 1: FunctionCallingMessage (single tool call) + if isinstance(rec.message, FunctionCallingMessage): + if rec.message.tool_call_id == tool_id: + matched = True + + # Case 2: BaseMessage with multiple tool_calls in meta_dict + elif ( + hasattr(rec.message, "meta_dict") + and rec.message.meta_dict + ): + tool_calls_list = rec.message.meta_dict.get( + "tool_calls", [] + ) + if isinstance(tool_calls_list, list): + for tc in tool_calls_list: + if ( + isinstance(tc, dict) + and tc.get("id") == tool_id + ): + matched = True + break + + if matched: + # Return both assistant call and tool result indices + return [idx, last_idx] + + # Fallback: no tool_call_id, use heuristic + if self._is_tool_related_record(last_record): + # Remove last 2 (assume they are paired) + return [last_idx - 1, last_idx] if last_idx > 0 else [last_idx] + else: + return [last_idx] + + # Default: non tool-related tail => remove last one + return [last_idx] + + @staticmethod + def _serialize_tool_args(args: Dict[str, Any]) -> str: + try: + return json.dumps(args, ensure_ascii=False, sort_keys=True) + except TypeError: + return str(args) + + @classmethod + def _build_tool_signature( + cls, func_name: str, args: Dict[str, Any] + ) -> str: + args_repr = cls._serialize_tool_args(args) + return f"{func_name}:{args_repr}" + + def _describe_tool_call( + self, record: Optional[ToolCallingRecord] + ) -> Optional[str]: + if record is None: + return None + args_repr = self._serialize_tool_args(record.args) + return f"Tool `{record.tool_name}` invoked with arguments {args_repr}." 
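+
+    # Worked example (illustrative only, not part of the runtime path): for a
+    # hypothetical tool call `search(query="camel", top_k=3)`,
+    # `_build_tool_signature` returns 'search:{"query": "camel", "top_k": 3}'
+    # (arguments serialized with sorted keys), and `_describe_tool_call`
+    # renders it as
+    # 'Tool `search` invoked with arguments {"query": "camel", "top_k": 3}.'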
+ + def _update_last_tool_call_state( + self, record: Optional[ToolCallingRecord] + ) -> None: + """Track the most recent tool call and its identifying signature.""" + self._last_tool_call_record = record + if record is None: + self._last_tool_call_signature = None + return + + args = ( + record.args + if isinstance(record.args, dict) + else {"_raw": record.args} + ) + try: + signature = self._build_tool_signature(record.tool_name, args) + except Exception: # pragma: no cover - defensive guard + signature = None + self._last_tool_call_signature = signature + + def _format_tool_limit_notice(self) -> Optional[str]: + record = self._last_tool_call_record + description = self._describe_tool_call(record) + if description is None: + return None + notice_lines = [ + "[Tool Call Causing Token Limit]", + description, + ] + + if record is not None: + result = record.result + if isinstance(result, bytes): + result_repr = result.decode(errors="replace") + elif isinstance(result, str): + result_repr = result + else: + try: + result_repr = json.dumps( + result, ensure_ascii=False, sort_keys=True + ) + except (TypeError, ValueError): + result_repr = str(result) + + result_length = len(result_repr) + notice_lines.append(f"Tool result length: {result_length}") + if self.model_backend.token_limit != 999999999: + notice_lines.append( + f"Token limit: {self.model_backend.token_limit}" + ) + + return "\n".join(notice_lines) + + @staticmethod + def _append_user_messages_section( + summary_content: str, user_messages: List[str] + ) -> str: + section_title = "- **All User Messages**:" + sanitized_messages: List[str] = [] + for msg in user_messages: + if not isinstance(msg, str): + msg = str(msg) + cleaned = " ".join(msg.strip().splitlines()) + if cleaned: + sanitized_messages.append(cleaned) + + bullet_block = ( + "\n".join(f"- {m}" for m in sanitized_messages) + if sanitized_messages + else "- None noted" + ) + user_section = f"{section_title}\n{bullet_block}" + + summary_clean = summary_content.rstrip() + separator = "\n\n" if summary_clean else "" + return f"{summary_clean}{separator}{user_section}" + + def _reset_summary_state(self) -> None: + self._summary_token_count = 0 # Total tokens in summary messages + + def _calculate_next_summary_threshold(self) -> int: + r"""Calculate the next token threshold that should trigger + summarization. + + The threshold calculation follows a progressive strategy: + - First time: token_limit * (summarize_threshold / 100) + - Subsequent times: (limit - summary_token) / 2 + summary_token + + This ensures that as summaries accumulate, the threshold adapts + to maintain a reasonable balance between context and summaries. + + Returns: + int: The token count threshold for next summarization. + """ + token_limit = self.model_backend.token_limit + summary_token_count = self._summary_token_count + + # First summarization: use the percentage threshold + if summary_token_count == 0: + threshold = int(token_limit * self.summarize_threshold / 100) + else: + # Subsequent summarizations: adaptive threshold + threshold = int( + (token_limit - summary_token_count) + * self.summarize_threshold + / 100 + + summary_token_count + ) + + return threshold + + def _update_memory_with_summary( + self, summary: str, include_summaries: bool = False + ) -> None: + r"""Update memory with summary result. + + This method handles memory clearing and restoration of summaries based + on whether it's a progressive or full compression. 
+ """ + + summary_content: str = summary + + existing_summaries = [] + if not include_summaries: + messages, _ = self.memory.get_context() + for msg in messages: + content = msg.get('content', '') + if isinstance(content, str) and content.startswith( + '[CONTEXT_SUMMARY]' + ): + existing_summaries.append(msg) + + # Clear memory + self.clear_memory() + + # Restore old summaries (for progressive compression) + for old_summary in existing_summaries: + content = old_summary.get('content', '') + if not isinstance(content, str): + content = str(content) + summary_msg = BaseMessage.make_assistant_message( + role_name="assistant", content=content + ) + self.update_memory(summary_msg, OpenAIBackendRole.ASSISTANT) + + # Add new summary + new_summary_msg = BaseMessage.make_assistant_message( + role_name="assistant", content=summary_content + ) + self.update_memory(new_summary_msg, OpenAIBackendRole.ASSISTANT) + input_message = BaseMessage.make_assistant_message( + role_name="assistant", + content=( + "Please continue the conversation from " + "where we left it off without asking the user any further " + "questions. Continue with the last task that you were " + "asked to work on." + ), + ) + self.update_memory(input_message, OpenAIBackendRole.ASSISTANT) + # Update token count + try: + summary_tokens = ( + self.model_backend.token_counter.count_tokens_from_messages( + [{"role": "assistant", "content": summary_content}] + ) + ) + + if include_summaries: # Full compression - reset count + self._summary_token_count = summary_tokens + logger.info( + f"Full compression: Summary with {summary_tokens} tokens. " + f"Total summary tokens reset to: {summary_tokens}" + ) + else: # Progressive compression - accumulate + self._summary_token_count += summary_tokens + logger.info( + f"Progressive compression: New summary " + f"with {summary_tokens} tokens. " + f"Total summary tokens: " + f"{self._summary_token_count}" + ) + except Exception as e: + logger.warning(f"Failed to count summary tokens: {e}") + def _get_external_tool_names(self) -> Set[str]: r"""Returns a set of external tool names.""" return set(self._external_tool_schemas.keys()) @@ -820,16 +1180,6 @@ def update_memory( ) -> None: r"""Updates the agent memory with a new message. - If the single *message* exceeds the model's context window, it will - be **automatically split into multiple smaller chunks** before being - written into memory. This prevents later failures in - `ScoreBasedContextCreator` where an over-sized message cannot fit - into the available token budget at all. - - This slicing logic handles both regular text messages (in the - `content` field) and long tool call results (in the `result` field of - a `FunctionCallingMessage`). - Args: message (BaseMessage): The new message to add to the stored messages. @@ -839,151 +1189,15 @@ def update_memory( (default: :obj:`None`) (default: obj:`None`) """ - - # 1. Helper to write a record to memory - def _write_single_record( - message: BaseMessage, role: OpenAIBackendRole, timestamp: float - ): - self.memory.write_record( - MemoryRecord( - message=message, - role_at_backend=role, - timestamp=timestamp, - agent_id=self.agent_id, - ) - ) - - base_ts = ( - timestamp + record = MemoryRecord( + message=message, + role_at_backend=role, + timestamp=timestamp if timestamp is not None - else time.time_ns() / 1_000_000_000 - ) - - # 2. 
Get token handling utilities, fallback if unavailable - try: - context_creator = self.memory.get_context_creator() - token_counter = context_creator.token_counter - token_limit = context_creator.token_limit - except AttributeError: - _write_single_record(message, role, base_ts) - return - - # 3. Check if slicing is necessary - try: - current_tokens = token_counter.count_tokens_from_messages( - [message.to_openai_message(role)] - ) - - _, ctx_tokens = self.memory.get_context() - - remaining_budget = max(0, token_limit - ctx_tokens) - - if current_tokens <= remaining_budget: - _write_single_record(message, role, base_ts) - return - except Exception as e: - logger.warning( - f"Token calculation failed before chunking, " - f"writing message as-is. Error: {e}" - ) - _write_single_record(message, role, base_ts) - return - - # 4. Perform slicing - logger.warning( - f"Message with {current_tokens} tokens exceeds remaining budget " - f"of {remaining_budget}. Slicing into smaller chunks." + else time.time_ns() / 1_000_000_000, # Nanosecond precision + agent_id=self.agent_id, ) - - text_to_chunk: Optional[str] = None - is_function_result = False - - if isinstance(message, FunctionCallingMessage) and isinstance( - message.result, str - ): - text_to_chunk = message.result - is_function_result = True - elif isinstance(message.content, str): - text_to_chunk = message.content - - if not text_to_chunk or not text_to_chunk.strip(): - _write_single_record(message, role, base_ts) - return - # Encode the entire text to get a list of all token IDs - try: - all_token_ids = token_counter.encode(text_to_chunk) - except Exception as e: - logger.error(f"Failed to encode text for chunking: {e}") - _write_single_record(message, role, base_ts) # Fallback - return - - if not all_token_ids: - _write_single_record(message, role, base_ts) # Nothing to chunk - return - - # 1. Base chunk size: one-tenth of the smaller of (a) total token - # limit and (b) current remaining budget. This prevents us from - # creating chunks that are guaranteed to overflow the - # immediate context window. - base_chunk_size = max(1, remaining_budget) // 10 - - # 2. Each chunk gets a textual prefix such as: - # "[chunk 3/12 of a long message]\n" - # The prefix itself consumes tokens, so if we do not subtract its - # length the *total* tokens of the outgoing message (prefix + body) - # can exceed the intended bound. We estimate the prefix length - # with a representative example that is safely long enough for the - # vast majority of cases (three-digit indices). - sample_prefix = "[chunk 1/1000 of a long message]\n" - prefix_token_len = len(token_counter.encode(sample_prefix)) - - # 3. The real capacity for the message body is therefore the base - # chunk size minus the prefix length. Fallback to at least one - # token to avoid zero or negative sizes. - chunk_body_limit = max(1, base_chunk_size - prefix_token_len) - - # 4. Calculate how many chunks we will need with this body size. 
- num_chunks = math.ceil(len(all_token_ids) / chunk_body_limit) - group_id = str(uuid.uuid4()) - - for i in range(num_chunks): - start_idx = i * chunk_body_limit - end_idx = start_idx + chunk_body_limit - chunk_token_ids = all_token_ids[start_idx:end_idx] - - chunk_body = token_counter.decode(chunk_token_ids) - - prefix = f"[chunk {i + 1}/{num_chunks} of a long message]\n" - new_body = prefix + chunk_body - - if is_function_result and isinstance( - message, FunctionCallingMessage - ): - new_msg: BaseMessage = FunctionCallingMessage( - role_name=message.role_name, - role_type=message.role_type, - meta_dict=message.meta_dict, - content=message.content, - func_name=message.func_name, - args=message.args, - result=new_body, - tool_call_id=message.tool_call_id, - ) - else: - new_msg = message.create_new_instance(new_body) - - meta = (new_msg.meta_dict or {}).copy() - meta.update( - { - "chunk_idx": i + 1, - "chunk_total": num_chunks, - "chunk_group_id": group_id, - } - ) - new_msg.meta_dict = meta - - # Increment timestamp slightly to maintain order - _write_single_record(new_msg, role, base_ts + i * 1e-6) + self.memory.write_record(record) def load_memory(self, memory: AgentMemory) -> None: r"""Load the provided memory into the agent. @@ -1068,6 +1282,8 @@ def summarize( summary_prompt: Optional[str] = None, response_format: Optional[Type[BaseModel]] = None, working_directory: Optional[Union[str, Path]] = None, + include_summaries: bool = False, + add_user_messages: bool = True, ) -> Dict[str, Any]: r"""Summarize the agent's current conversation context and persist it to a markdown file. @@ -1087,10 +1303,16 @@ def summarize( defining the expected structure of the response. If provided, the summary will be generated as structured output and included in the result. + include_summaries (bool): Whether to include previously generated + summaries in the content to be summarized. If False (default), + only non-summary messages will be summarized. If True, all + messages including previous summaries will be summarized + (full compression). (default: :obj:`False`) working_directory (Optional[str|Path]): Optional directory to save the markdown summary file. If provided, overrides the default directory used by ContextUtility. - + add_user_messages (bool): Whether add user messages to summary. 
+ (default: :obj:`True`) Returns: Dict[str, Any]: A dictionary containing the summary text, file path, status message, and optionally structured_summary if @@ -1136,10 +1358,17 @@ def summarize( # Convert messages to conversation text conversation_lines = [] + user_messages: List[str] = [] for message in messages: role = message.get('role', 'unknown') content = message.get('content', '') + # Skip summary messages if include_summaries is False + if not include_summaries and isinstance(content, str): + # Check if this is a summary message by looking for marker + if content.startswith('[CONTEXT_SUMMARY]'): + continue + # Handle tool call messages (assistant calling tools) tool_calls = message.get('tool_calls') if tool_calls and isinstance(tool_calls, (list, tuple)): @@ -1191,6 +1420,9 @@ def summarize( # Handle regular content messages (user/assistant/system) elif content: + content = str(content) + if role == 'user': + user_messages.append(content) conversation_lines.append(f"{role}: {content}") conversation_text = "\n".join(conversation_lines).strip() @@ -1221,11 +1453,7 @@ def summarize( f"{conversation_text}" ) else: - prompt_text = ( - "Summarize the context information in concise markdown " - "bullet points highlighting key decisions, action items.\n" - f"Context information:\n{conversation_text}" - ) + prompt_text = build_default_summary_prompt(conversation_text) try: # Use structured output if response_format is provided @@ -1295,6 +1523,10 @@ def summarize( summary_content = context_util.structured_output_to_markdown( structured_data=structured_output, metadata=metadata ) + if add_user_messages: + summary_content = self._append_user_messages_section( + summary_content, user_messages + ) # Save the markdown (either custom structured or default) save_status = context_util.save_markdown_file( @@ -1309,7 +1541,10 @@ def summarize( file_path = ( context_util.get_working_directory() / f"{base_filename}.md" ) - + summary_content = ( + f"[CONTEXT_SUMMARY] The following is a summary of our " + f"conversation from a previous session: {summary_content}" + ) # Prepare result dictionary result_dict = { "summary": summary_content, @@ -1334,6 +1569,8 @@ async def asummarize( summary_prompt: Optional[str] = None, response_format: Optional[Type[BaseModel]] = None, working_directory: Optional[Union[str, Path]] = None, + include_summaries: bool = False, + add_user_messages: bool = True, ) -> Dict[str, Any]: r"""Asynchronously summarize the agent's current conversation context and persist it to a markdown file. @@ -1356,7 +1593,13 @@ async def asummarize( working_directory (Optional[str|Path]): Optional directory to save the markdown summary file. If provided, overrides the default directory used by ContextUtility. - + include_summaries (bool): Whether to include previously generated + summaries in the content to be summarized. If False (default), + only non-summary messages will be summarized. If True, all + messages including previous summaries will be summarized + (full compression). (default: :obj:`False`) + add_user_messages (bool): Whether add user messages to summary. 
+ (default: :obj:`True`) Returns: Dict[str, Any]: A dictionary containing the summary text, file path, status message, and optionally structured_summary if @@ -1392,10 +1635,17 @@ async def asummarize( # Convert messages to conversation text conversation_lines = [] + user_messages: List[str] = [] for message in messages: role = message.get('role', 'unknown') content = message.get('content', '') + # Skip summary messages if include_summaries is False + if not include_summaries and isinstance(content, str): + # Check if this is a summary message by looking for marker + if content.startswith('[CONTEXT_SUMMARY]'): + continue + # Handle tool call messages (assistant calling tools) tool_calls = message.get('tool_calls') if tool_calls and isinstance(tool_calls, (list, tuple)): @@ -1447,6 +1697,9 @@ async def asummarize( # Handle regular content messages (user/assistant/system) elif content: + content = str(content) + if role == 'user': + user_messages.append(content) conversation_lines.append(f"{role}: {content}") conversation_text = "\n".join(conversation_lines).strip() @@ -1477,11 +1730,7 @@ async def asummarize( f"{conversation_text}" ) else: - prompt_text = ( - "Summarize the context information in concise markdown " - "bullet points highlighting key decisions, action items.\n" - f"Context information:\n{conversation_text}" - ) + prompt_text = build_default_summary_prompt(conversation_text) try: # Use structured output if response_format is provided @@ -1560,6 +1809,10 @@ async def asummarize( summary_content = context_util.structured_output_to_markdown( structured_data=structured_output, metadata=metadata ) + if add_user_messages: + summary_content = self._append_user_messages_section( + summary_content, user_messages + ) # Save the markdown (either custom structured or default) save_status = context_util.save_markdown_file( @@ -1575,6 +1828,11 @@ async def asummarize( context_util.get_working_directory() / f"{base_filename}.md" ) + summary_content = ( + f"[CONTEXT_SUMMARY] The following is a summary of our " + f"conversation from a previous session: {summary_content}" + ) + # Prepare result dictionary result_dict = { "summary": summary_content, @@ -1602,7 +1860,14 @@ def clear_memory(self) -> None: self.memory.clear() if self.system_message is not None: - self.update_memory(self.system_message, OpenAIBackendRole.SYSTEM) + self.memory.write_record( + MemoryRecord( + message=self.system_message, + role_at_backend=OpenAIBackendRole.SYSTEM, + timestamp=time.time_ns() / 1_000_000_000, + agent_id=self.agent_id, + ) + ) def _generate_system_message_for_output_language( self, @@ -1633,17 +1898,8 @@ def init_messages(self) -> None: r"""Initializes the stored messages list with the current system message. 
""" - self.memory.clear() - # Write system message to memory if provided - if self.system_message is not None: - self.memory.write_record( - MemoryRecord( - message=self.system_message, - role_at_backend=OpenAIBackendRole.SYSTEM, - timestamp=time.time_ns() / 1_000_000_000, - agent_id=self.agent_id, - ) - ) + self._reset_summary_state() + self.clear_memory() def update_system_message( self, @@ -2123,22 +2379,122 @@ def _step_impl( try: openai_messages, num_tokens = self.memory.get_context() + if self.summarize_threshold is not None: + threshold = self._calculate_next_summary_threshold() + summary_token_count = self._summary_token_count + token_limit = self.model_backend.token_limit + + if num_tokens <= token_limit: + if ( + summary_token_count + > token_limit * self.summary_window_ratio + ): + logger.info( + f"Summary tokens ({summary_token_count}) " + f"exceed limit, full compression." + ) + # Summarize everything (including summaries) + summary = self.summarize(include_summaries=True) + self._update_memory_with_summary( + summary.get("summary", ""), + include_summaries=True, + ) + elif num_tokens > threshold: + logger.info( + f"Token count ({num_tokens}) exceed threshold " + f"({threshold}). Triggering summarization." + ) + # Only summarize non-summary content + summary = self.summarize(include_summaries=False) + self._update_memory_with_summary( + summary.get("summary", ""), + include_summaries=False, + ) accumulated_context_tokens += num_tokens except RuntimeError as e: return self._step_terminate( e.args[1], tool_call_records, "max_tokens_exceeded" ) - # Get response from model backend - response = self._get_model_response( - openai_messages, - num_tokens=num_tokens, - current_iteration=iteration_count, - response_format=response_format, - tool_schemas=[] - if disable_tools - else self._get_full_tool_schemas(), - prev_num_openai_messages=prev_num_openai_messages, - ) + # Get response from model backend with token limit error handling + try: + response = self._get_model_response( + openai_messages, + num_tokens=num_tokens, + current_iteration=iteration_count, + response_format=response_format, + tool_schemas=[] + if disable_tools + else self._get_full_tool_schemas(), + prev_num_openai_messages=prev_num_openai_messages, + ) + except Exception as exc: + logger.exception("Model error: %s", exc) + + if self._is_token_limit_error(exc): + tool_signature = self._last_tool_call_signature + if ( + tool_signature is not None + and tool_signature + == self._last_token_limit_tool_signature + ): + description = self._describe_tool_call( + self._last_tool_call_record + ) + repeated_msg = ( + "Context exceeded again by the same tool call." + ) + if description: + repeated_msg += f" {description}" + raise RuntimeError(repeated_msg) from exc + + user_message_count = sum( + 1 + for msg in openai_messages + if getattr(msg, "role", None) == "user" + ) + if ( + user_message_count == 1 + and getattr(openai_messages[-1], "role", None) + == "user" + ): + raise RuntimeError( + "The provided user input alone exceeds the " + "context window. Please shorten the input." + ) from exc + + logger.warning( + "Token limit exceeded error detected. " + "Summarizing context." 
+ ) + + recent_records: List[ContextRecord] + try: + recent_records = self.memory.retrieve() + except Exception: # pragma: no cover - defensive guard + recent_records = [] + + indices_to_remove = ( + self._find_indices_to_remove_for_last_tool_pair( + recent_records + ) + ) + self.memory.remove_records_by_indices(indices_to_remove) + + summary = self.summarize(include_summaries=False) + tool_notice = self._format_tool_limit_notice() + summary_messages = summary.get("summary", "") + + if tool_notice: + summary_messages += "\n\n" + tool_notice + + self._update_memory_with_summary( + summary_messages, include_summaries=False + ) + self._last_token_limit_tool_signature = tool_signature + return self._step_impl(input_message, response_format) + + raise + prev_num_openai_messages = len(openai_messages) iteration_count += 1 @@ -2333,6 +2689,7 @@ async def _astep_non_streaming_task( step_token_usage = self._create_token_usage_tracker() iteration_count: int = 0 prev_num_openai_messages: int = 0 + while True: if self.pause_event is not None and not self.pause_event.is_set(): if isinstance(self.pause_event, asyncio.Event): @@ -2343,21 +2700,128 @@ async def _astep_non_streaming_task( await loop.run_in_executor(None, self.pause_event.wait) try: openai_messages, num_tokens = self.memory.get_context() + if self.summarize_threshold is not None: + threshold = self._calculate_next_summary_threshold() + summary_token_count = self._summary_token_count + token_limit = self.model_backend.token_limit + + if num_tokens <= token_limit: + if ( + summary_token_count + > token_limit * self.summary_window_ratio + ): + logger.info( + f"Summary tokens ({summary_token_count}) " + f"exceed limit, full compression." + ) + # Summarize everything (including summaries) + summary = await self.asummarize( + include_summaries=True + ) + self._update_memory_with_summary( + summary.get("summary", ""), + include_summaries=True, + ) + elif num_tokens > threshold: + logger.info( + f"Token count ({num_tokens}) exceed threshold " + "({threshold}). Triggering summarization." + ) + # Only summarize non-summary content + summary = await self.asummarize( + include_summaries=False + ) + self._update_memory_with_summary( + summary.get("summary", ""), + include_summaries=False, + ) accumulated_context_tokens += num_tokens except RuntimeError as e: return self._step_terminate( e.args[1], tool_call_records, "max_tokens_exceeded" ) - response = await self._aget_model_response( - openai_messages, - num_tokens=num_tokens, - current_iteration=iteration_count, - response_format=response_format, - tool_schemas=[] - if disable_tools - else self._get_full_tool_schemas(), - prev_num_openai_messages=prev_num_openai_messages, - ) + # Get response from model backend with token limit error handling + try: + response = await self._aget_model_response( + openai_messages, + num_tokens=num_tokens, + current_iteration=iteration_count, + response_format=response_format, + tool_schemas=[] + if disable_tools + else self._get_full_tool_schemas(), + prev_num_openai_messages=prev_num_openai_messages, + ) + except Exception as exc: + logger.exception("Model error: %s", exc) + + if self._is_token_limit_error(exc): + tool_signature = self._last_tool_call_signature + if ( + tool_signature is not None + and tool_signature + == self._last_token_limit_tool_signature + ): + description = self._describe_tool_call( + self._last_tool_call_record + ) + repeated_msg = ( + "Context exceeded again by the same tool call." 
+ ) + if description: + repeated_msg += f" {description}" + raise RuntimeError(repeated_msg) from exc + + user_message_count = sum( + 1 + for msg in openai_messages + if getattr(msg, "role", None) == "user" + ) + if ( + user_message_count == 1 + and getattr(openai_messages[-1], "role", None) + == "user" + ): + raise RuntimeError( + "The provided user input alone exceeds the" + "context window. Please shorten the input." + ) from exc + + logger.warning( + "Token limit exceeded error detected. " + "Summarizing context." + ) + + recent_records: List[ContextRecord] + try: + recent_records = self.memory.retrieve() + except Exception: # pragma: no cover - defensive guard + recent_records = [] + + indices_to_remove = ( + self._find_indices_to_remove_for_last_tool_pair( + recent_records + ) + ) + self.memory.remove_records_by_indices(indices_to_remove) + + summary = await self.asummarize() + + tool_notice = self._format_tool_limit_notice() + summary_messages = summary.get("summary", "") + + if tool_notice: + summary_messages += "\n\n" + tool_notice + self._update_memory_with_summary( + summary_messages, include_summaries=False + ) + self._last_token_limit_tool_signature = tool_signature + return await self._astep_non_streaming_task( + input_message, response_format + ) + + raise + prev_num_openai_messages = len(openai_messages) iteration_count += 1 @@ -2434,6 +2898,8 @@ async def _astep_non_streaming_task( if self.prune_tool_calls_from_memory and tool_call_records: self.memory.clean_tool_calls() + self._last_token_limit_user_signature = None + return self._convert_to_chatagent_response( response, tool_call_records, @@ -2530,6 +2996,8 @@ def _get_model_response( if response: break except RateLimitError as e: + if self._is_token_limit_error(e): + raise last_error = e if attempt < self.retry_attempts - 1: delay = min(self.retry_delay * (2**attempt), 60.0) @@ -2547,7 +3015,6 @@ def _get_model_response( except Exception: logger.error( f"Model error: {self.model_backend.model_type}", - exc_info=True, ) raise else: @@ -2594,6 +3061,8 @@ async def _aget_model_response( if response: break except RateLimitError as e: + if self._is_token_limit_error(e): + raise last_error = e if attempt < self.retry_attempts - 1: delay = min(self.retry_delay * (2**attempt), 60.0) @@ -3101,6 +3570,7 @@ def _record_tool_calling( tool_call_id=tool_call_id, ) + self._update_last_tool_call_state(tool_record) return tool_record def _stream( @@ -3662,12 +4132,14 @@ def _execute_tool_from_stream_data( timestamp=base_timestamp + 1e-6, ) - return ToolCallingRecord( + tool_record = ToolCallingRecord( tool_name=function_name, args=args, result=result, tool_call_id=tool_call_id, ) + self._update_last_tool_call_state(tool_record) + return tool_record except Exception as e: error_msg = ( @@ -3689,12 +4161,14 @@ def _execute_tool_from_stream_data( self.update_memory(func_msg, OpenAIBackendRole.FUNCTION) - return ToolCallingRecord( + tool_record = ToolCallingRecord( tool_name=function_name, args=args, result=result, tool_call_id=tool_call_id, ) + self._update_last_tool_call_state(tool_record) + return tool_record else: logger.warning( f"Tool '{function_name}' not found in internal tools" @@ -3785,12 +4259,14 @@ async def _aexecute_tool_from_stream_data( timestamp=base_timestamp + 1e-6, ) - return ToolCallingRecord( + tool_record = ToolCallingRecord( tool_name=function_name, args=args, result=result, tool_call_id=tool_call_id, ) + self._update_last_tool_call_state(tool_record) + return tool_record except Exception as e: error_msg = ( @@ 
-3812,12 +4288,14 @@ async def _aexecute_tool_from_stream_data( self.update_memory(func_msg, OpenAIBackendRole.FUNCTION) - return ToolCallingRecord( + tool_record = ToolCallingRecord( tool_name=function_name, args=args, result=result, tool_call_id=tool_call_id, ) + self._update_last_tool_call_state(tool_record) + return tool_record else: logger.warning( f"Tool '{function_name}' not found in internal tools" diff --git a/camel/memories/agent_memories.py b/camel/memories/agent_memories.py index 01abee68a1..1789cec469 100644 --- a/camel/memories/agent_memories.py +++ b/camel/memories/agent_memories.py @@ -129,6 +129,16 @@ def clean_tool_calls(self) -> None: # Save the modified records back to storage self._chat_history_block.storage.save(record_dicts) + def pop_records(self, count: int) -> List[MemoryRecord]: + r"""Removes the most recent records from chat history memory.""" + return self._chat_history_block.pop_records(count) + + def remove_records_by_indices( + self, indices: List[int] + ) -> List[MemoryRecord]: + r"""Removes records at specified indices from chat history memory.""" + return self._chat_history_block.remove_records_by_indices(indices) + class VectorDBMemory(AgentMemory): r"""An agent memory wrapper of :obj:`VectorDBBlock`. This memory queries @@ -193,6 +203,20 @@ def clear(self) -> None: r"""Removes all records from the vector database memory.""" self._vectordb_block.clear() + def pop_records(self, count: int) -> List[MemoryRecord]: + r"""Rolling back is unsupported for vector database memory.""" + raise NotImplementedError( + "VectorDBMemory does not support removing historical records." + ) + + def remove_records_by_indices( + self, indices: List[int] + ) -> List[MemoryRecord]: + r"""Removing by indices is unsupported for vector database memory.""" + raise NotImplementedError( + "VectorDBMemory does not support removing records by indices." + ) + class LongtermAgentMemory(AgentMemory): r"""An implementation of the :obj:`AgentMemory` abstract base class for @@ -277,3 +301,13 @@ def clear(self) -> None: r"""Removes all records from the memory.""" self.chat_history_block.clear() self.vector_db_block.clear() + + def pop_records(self, count: int) -> List[MemoryRecord]: + r"""Removes recent chat history records while leaving vector memory.""" + return self.chat_history_block.pop_records(count) + + def remove_records_by_indices( + self, indices: List[int] + ) -> List[MemoryRecord]: + r"""Removes records at specified indices from chat history.""" + return self.chat_history_block.remove_records_by_indices(indices) diff --git a/camel/memories/base.py b/camel/memories/base.py index f9d4a0ad83..4cf2ccce15 100644 --- a/camel/memories/base.py +++ b/camel/memories/base.py @@ -45,6 +45,32 @@ def write_record(self, record: MemoryRecord) -> None: """ self.write_records([record]) + def pop_records(self, count: int) -> List[MemoryRecord]: + r"""Removes records from the memory and returns the removed records. + + Args: + count (int): Number of records to remove. + + Returns: + List[MemoryRecord]: The records that were removed from the memory + in their original order. + """ + raise NotImplementedError + + def remove_records_by_indices( + self, indices: List[int] + ) -> List[MemoryRecord]: + r"""Removes records at specified indices from the memory. + + Args: + indices (List[int]): List of indices to remove. Indices should be + valid positions in the current record list. + + Returns: + List[MemoryRecord]: The removed records in their original order. 
+ """ + raise NotImplementedError + @abstractmethod def clear(self) -> None: r"""Clears all messages from the memory.""" diff --git a/camel/memories/blocks/chat_history_block.py b/camel/memories/blocks/chat_history_block.py index 8beaa909e4..1f311f131a 100644 --- a/camel/memories/blocks/chat_history_block.py +++ b/camel/memories/blocks/chat_history_block.py @@ -167,3 +167,118 @@ def write_records(self, records: List[MemoryRecord]) -> None: def clear(self) -> None: r"""Clears all chat messages from the memory.""" self.storage.clear() + + def pop_records(self, count: int) -> List[MemoryRecord]: + r"""Removes the most recent records from the memory. + + Args: + count (int): Number of records to remove from the end of the + conversation history. A value of 0 results in no changes. + + Returns: + List[MemoryRecord]: The removed records in chronological order. + """ + if not isinstance(count, int): + raise TypeError("`count` must be an integer.") + if count < 0: + raise ValueError("`count` must be non-negative.") + if count == 0: + return [] + + record_dicts = self.storage.load() + if not record_dicts: + return [] + + # Preserve initial system/developer instruction if present. + protected_prefix = ( + 1 + if ( + record_dicts + and record_dicts[0]['role_at_backend'] + in { + OpenAIBackendRole.SYSTEM.value, + OpenAIBackendRole.DEVELOPER.value, + } + ) + else 0 + ) + + removable_count = max(len(record_dicts) - protected_prefix, 0) + if removable_count == 0: + return [] + + pop_count = min(count, removable_count) + split_index = len(record_dicts) - pop_count + + popped_dicts = record_dicts[split_index:] + remaining_dicts = record_dicts[:split_index] + + self.storage.clear() + if remaining_dicts: + self.storage.save(remaining_dicts) + + return [MemoryRecord.from_dict(record) for record in popped_dicts] + + def remove_records_by_indices( + self, indices: List[int] + ) -> List[MemoryRecord]: + r"""Removes records at specified indices from the memory. + + Args: + indices (List[int]): List of indices to remove. Indices are + positions in the current record list (0-based). + System/developer messages at index 0 are protected and will + not be removed. + + Returns: + List[MemoryRecord]: The removed records in their original order. + """ + if not indices: + return [] + + record_dicts = self.storage.load() + if not record_dicts: + return [] + + # Preserve initial system/developer instruction if present. 
+ protected_prefix = ( + 1 + if ( + record_dicts + and record_dicts[0]['role_at_backend'] + in { + OpenAIBackendRole.SYSTEM.value, + OpenAIBackendRole.DEVELOPER.value, + } + ) + else 0 + ) + + # Filter out protected indices and invalid ones + valid_indices = sorted( + { + idx + for idx in indices + if idx >= protected_prefix and idx < len(record_dicts) + } + ) + + if not valid_indices: + return [] + + # Extract records to remove (in original order) + removed_records = [record_dicts[idx] for idx in valid_indices] + + # Build remaining records by excluding removed indices + remaining_dicts = [ + record + for idx, record in enumerate(record_dicts) + if idx not in valid_indices + ] + + # Save back to storage + self.storage.clear() + if remaining_dicts: + self.storage.save(remaining_dicts) + + return [MemoryRecord.from_dict(record) for record in removed_records] diff --git a/camel/memories/context_creators/score_based.py b/camel/memories/context_creators/score_based.py index 6d2b9ea349..6733a38f8e 100644 --- a/camel/memories/context_creators/score_based.py +++ b/camel/memories/context_creators/score_based.py @@ -11,41 +11,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= -from collections import defaultdict -from typing import Dict, List, Optional, Tuple -from pydantic import BaseModel +from typing import List, Optional, Tuple -from camel.logger import get_logger from camel.memories.base import BaseContextCreator from camel.memories.records import ContextRecord -from camel.messages import FunctionCallingMessage, OpenAIMessage +from camel.messages import OpenAIMessage from camel.types.enums import OpenAIBackendRole from camel.utils import BaseTokenCounter -logger = get_logger(__name__) - - -class _ContextUnit(BaseModel): - idx: int - record: ContextRecord - num_tokens: int - class ScoreBasedContextCreator(BaseContextCreator): - r"""A default implementation of context creation strategy, which inherits - from :obj:`BaseContextCreator`. - - This class provides a strategy to generate a conversational context from - a list of chat history records while ensuring the total token count of - the context does not exceed a specified limit. It prunes messages based - on their score if the total token count exceeds the limit. + r"""A context creation strategy that orders records chronologically. Args: - token_counter (BaseTokenCounter): An instance responsible for counting - tokens in a message. - token_limit (int): The maximum number of tokens allowed in the - generated context. + token_counter (BaseTokenCounter): Token counter instance used to + compute the combined token count of the returned messages. + token_limit (int): Retained for API compatibility. No longer used to + filter records. """ def __init__( @@ -66,376 +49,34 @@ def create_context( self, records: List[ContextRecord], ) -> Tuple[List[OpenAIMessage], int]: - r"""Constructs conversation context from chat history while respecting - token limits. - - Key strategies: - 1. System message is always prioritized and preserved - 2. Truncation removes low-score messages first - 3. Final output maintains chronological order and in history memory, - the score of each message decreases according to keep_rate. The - newer the message, the higher the score. - 4. 
Tool calls and their responses are kept together to maintain - API compatibility - - Args: - records (List[ContextRecord]): List of context records with scores - and timestamps. - - Returns: - Tuple[List[OpenAIMessage], int]: - - Ordered list of OpenAI messages - - Total token count of the final context - - Raises: - RuntimeError: If system message alone exceeds token limit - """ - # ====================== - # 1. System Message Handling - # ====================== - system_unit, regular_units = self._extract_system_message(records) - system_tokens = system_unit.num_tokens if system_unit else 0 + """Returns messages sorted by timestamp and their total token count.""" - # Check early if system message alone exceeds token limit - if system_tokens > self.token_limit: - raise RuntimeError( - f"System message alone exceeds token limit" - f": {system_tokens} > {self.token_limit}", - system_tokens, - ) + system_record: Optional[ContextRecord] = None + remaining_records: List[ContextRecord] = [] - # ====================== - # 2. Deduplication & Initial Processing - # ====================== - seen_uuids = set() - if system_unit: - seen_uuids.add(system_unit.record.memory_record.uuid) - - # Process non-system messages with deduplication - for idx, record in enumerate(records): + for record in records: if ( - record.memory_record.role_at_backend + system_record is None + and record.memory_record.role_at_backend == OpenAIBackendRole.SYSTEM ): + system_record = record continue - if record.memory_record.uuid in seen_uuids: - continue - seen_uuids.add(record.memory_record.uuid) - - token_count = self.token_counter.count_tokens_from_messages( - [record.memory_record.to_openai_message()] - ) - regular_units.append( - _ContextUnit( - idx=idx, - record=record, - num_tokens=token_count, - ) - ) - - # ====================== - # 3. Tool Call Relationship Mapping - # ====================== - tool_call_groups = self._group_tool_calls_and_responses(regular_units) - - # ====================== - # 4. Token Calculation - # ====================== - total_tokens = system_tokens + sum(u.num_tokens for u in regular_units) - - # ====================== - # 5. Early Return if Within Limit - # ====================== - if total_tokens <= self.token_limit: - sorted_units = sorted( - regular_units, key=self._conversation_sort_key - ) - return self._assemble_output(sorted_units, system_unit) - - # ====================== - # 6. Truncation Logic with Tool Call Awareness - # ====================== - remaining_units = self._truncate_with_tool_call_awareness( - regular_units, tool_call_groups, system_tokens - ) - - # Log only after truncation is actually performed so that both - # the original and the final token counts are visible. - tokens_after = system_tokens + sum( - u.num_tokens for u in remaining_units - ) - logger.warning( - "Context truncation performed: " - f"before={total_tokens}, after={tokens_after}, " - f"limit={self.token_limit}" - ) - - # ====================== - # 7. 
Output Assembly - # ====================== - - # In case system message is the only message in memory when sorted - # units are empty, raise an error - if system_unit and len(remaining_units) == 0 and len(records) > 1: - raise RuntimeError( - "System message and current message exceeds token limit ", - total_tokens, - ) - - # Sort remaining units chronologically - final_units = sorted(remaining_units, key=self._conversation_sort_key) - return self._assemble_output(final_units, system_unit) - - def _group_tool_calls_and_responses( - self, units: List[_ContextUnit] - ) -> Dict[str, List[_ContextUnit]]: - r"""Groups tool calls with their corresponding responses based on - `tool_call_id`. - - This improved logic robustly gathers all messages (assistant requests - and tool responses, including chunks) that share a `tool_call_id`. - - Args: - units (List[_ContextUnit]): List of context units to analyze. - - Returns: - Dict[str, List[_ContextUnit]]: Mapping from `tool_call_id` to a - list of related units. - """ - tool_call_groups: Dict[str, List[_ContextUnit]] = defaultdict(list) - - for unit in units: - # FunctionCallingMessage stores tool_call_id. - message = unit.record.memory_record.message - tool_call_id = getattr(message, 'tool_call_id', None) - - if tool_call_id: - tool_call_groups[tool_call_id].append(unit) - - # Filter out empty or incomplete groups if necessary, - # though defaultdict and getattr handle this gracefully. - return dict(tool_call_groups) - - def _truncate_with_tool_call_awareness( - self, - regular_units: List[_ContextUnit], - tool_call_groups: Dict[str, List[_ContextUnit]], - system_tokens: int, - ) -> List[_ContextUnit]: - r"""Truncates messages while preserving tool call-response pairs. - This method implements a more sophisticated truncation strategy: - 1. It treats tool call groups (request + responses) and standalone - messages as individual items to be included. - 2. It sorts all items by score and greedily adds them to the context. - 3. **Partial Truncation**: If a complete tool group is too large to - fit,it attempts to add the request message and as many of the most - recent response chunks as the token budget allows. - - Args: - regular_units (List[_ContextUnit]): All regular message units. - tool_call_groups (Dict[str, List[_ContextUnit]]): Grouped tool - calls. - system_tokens (int): Tokens used by the system message. - - Returns: - List[_ContextUnit]: A list of units that fit within the token - limit. 
- """ - - # Create a set for quick lookup of units belonging to any tool call - tool_call_unit_ids = { - unit.record.memory_record.uuid - for group in tool_call_groups.values() - for unit in group - } - - # Separate standalone units from tool call groups - standalone_units = [ - u - for u in regular_units - if u.record.memory_record.uuid not in tool_call_unit_ids - ] - - # Prepare all items (standalone units and groups) for sorting - all_potential_items: List[Dict] = [] - for unit in standalone_units: - all_potential_items.append( - { - "type": "standalone", - "score": unit.record.score, - "timestamp": unit.record.timestamp, - "tokens": unit.num_tokens, - "item": unit, - } - ) - for group in tool_call_groups.values(): - all_potential_items.append( - { - "type": "group", - "score": max(u.record.score for u in group), - "timestamp": max(u.record.timestamp for u in group), - "tokens": sum(u.num_tokens for u in group), - "item": group, - } - ) - - # Sort all potential items by score (high to low), then timestamp - all_potential_items.sort(key=lambda x: (-x["score"], -x["timestamp"])) - - remaining_units: List[_ContextUnit] = [] - current_tokens = system_tokens - - for item_dict in all_potential_items: - item_type = item_dict["type"] - item = item_dict["item"] - item_tokens = item_dict["tokens"] - - if current_tokens + item_tokens <= self.token_limit: - # The whole item (standalone or group) fits, so add it - if item_type == "standalone": - remaining_units.append(item) - else: # item_type == "group" - remaining_units.extend(item) - current_tokens += item_tokens - - elif item_type == "group": - # The group does not fit completely; try partial inclusion. - request_unit: Optional[_ContextUnit] = None - response_units: List[_ContextUnit] = [] - - for unit in item: - # Assistant msg with `args` is the request - if ( - isinstance( - unit.record.memory_record.message, - FunctionCallingMessage, - ) - and unit.record.memory_record.message.args is not None - ): - request_unit = unit - else: - response_units.append(unit) - - # A group must have a request to be considered for inclusion. - if request_unit is None: - continue - - # Check if we can at least fit the request. - if ( - current_tokens + request_unit.num_tokens - <= self.token_limit - ): - units_to_add = [request_unit] - tokens_to_add = request_unit.num_tokens - - # Sort responses by timestamp to add newest chunks first - response_units.sort( - key=lambda u: u.record.timestamp, reverse=True - ) + remaining_records.append(record) - for resp_unit in response_units: - if ( - current_tokens - + tokens_to_add - + resp_unit.num_tokens - <= self.token_limit - ): - units_to_add.append(resp_unit) - tokens_to_add += resp_unit.num_tokens + remaining_records.sort(key=lambda record: record.timestamp) - # A request must be followed by at least one response - if len(units_to_add) > 1: - remaining_units.extend(units_to_add) - current_tokens += tokens_to_add + messages: List[OpenAIMessage] = [] + if system_record is not None: + messages.append(system_record.memory_record.to_openai_message()) - return remaining_units - - def _extract_system_message( - self, records: List[ContextRecord] - ) -> Tuple[Optional[_ContextUnit], List[_ContextUnit]]: - r"""Extracts the system message from records and validates it. - - Args: - records (List[ContextRecord]): List of context records - representing conversation history. - - Returns: - Tuple[Optional[_ContextUnit], List[_ContextUnit]]: containing: - - The system message as a `_ContextUnit`, if valid; otherwise, - `None`. 
- - An empty list, serving as the initial container for regular - messages. - """ - if not records: - return None, [] - - first_record = records[0] - if ( - first_record.memory_record.role_at_backend - != OpenAIBackendRole.SYSTEM - ): - return None, [] - - message = first_record.memory_record.to_openai_message() - tokens = self.token_counter.count_tokens_from_messages([message]) - system_message_unit = _ContextUnit( - idx=0, - record=first_record, - num_tokens=tokens, + messages.extend( + record.memory_record.to_openai_message() + for record in remaining_records ) - return system_message_unit, [] - - def _conversation_sort_key( - self, unit: _ContextUnit - ) -> Tuple[float, float]: - r"""Defines the sorting key for assembling the final output. - - Sorting priority: - - Primary: Sort by timestamp in ascending order (chronological order). - - Secondary: Sort by score in descending order (higher scores first - when timestamps are equal). - - Args: - unit (_ContextUnit): A `_ContextUnit` representing a conversation - record. - - Returns: - Tuple[float, float]: - - Timestamp for chronological sorting. - - Negative score for descending order sorting. - """ - return (unit.record.timestamp, -unit.record.score) - - def _assemble_output( - self, - context_units: List[_ContextUnit], - system_unit: Optional[_ContextUnit], - ) -> Tuple[List[OpenAIMessage], int]: - r"""Assembles final message list with proper ordering and token count. - - Args: - context_units (List[_ContextUnit]): Sorted list of regular message - units. - system_unit (Optional[_ContextUnit]): System message unit (if - present). - - Returns: - Tuple[List[OpenAIMessage], int]: Tuple of (ordered messages, total - tokens) - """ - messages = [] - total_tokens = 0 - - # Add system message first if present - if system_unit: - messages.append( - system_unit.record.memory_record.to_openai_message() - ) - total_tokens += system_unit.num_tokens - # Add sorted regular messages - for unit in context_units: - messages.append(unit.record.memory_record.to_openai_message()) - total_tokens += unit.num_tokens + if not messages: + return [], 0 + total_tokens = self.token_counter.count_tokens_from_messages(messages) return messages, total_tokens diff --git a/camel/storages/vectordb_storages/oceanbase.py b/camel/storages/vectordb_storages/oceanbase.py index 7cfe4b5c16..c250e87b0b 100644 --- a/camel/storages/vectordb_storages/oceanbase.py +++ b/camel/storages/vectordb_storages/oceanbase.py @@ -121,10 +121,11 @@ def __init__( ) # Get the first index parameter - first_index_param = next(iter(index_params)) - self._client.create_vidx_with_vec_index_param( - table_name=self.table_name, vidx_param=first_index_param - ) + first_index_param = next(iter(index_params), None) + if first_index_param is not None: + self._client.create_vidx_with_vec_index_param( + table_name=self.table_name, vidx_param=first_index_param + ) logger.info(f"Created table {self.table_name} with vector index") else: diff --git a/examples/summarization/handle_token_limit.py b/examples/summarization/handle_token_limit.py new file mode 100644 index 0000000000..9f810de855 --- /dev/null +++ b/examples/summarization/handle_token_limit.py @@ -0,0 +1,68 @@ +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+import os
+
+from camel.agents import ChatAgent
+from camel.models import ModelFactory
+from camel.toolkits import TerminalToolkit
+from camel.types import ModelPlatformType, ModelType
+
+# Get current script directory
+base_dir = os.path.dirname(os.path.abspath(__file__))
+# Define workspace directory for the toolkit
+workspace_dir = os.path.join(
+    os.path.dirname(os.path.dirname(base_dir)), "workspace"
+)
+
+# Define system message
+sys_msg = (
+    "You are a System Administrator helping with log management tasks. "
+    "You have access to terminal tools that can help you execute "
+    "shell commands and search files. "
+)
+
+# Set up terminal toolkit tools
+tools = TerminalToolkit(working_directory=workspace_dir).get_tools()
+
+
+model = ModelFactory.create(
+    model_platform=ModelPlatformType.OPENAI,
+    model_type=ModelType.GPT_3_5_TURBO,
+)
+
+# Set agent
+camel_agent = ChatAgent(
+    system_message=sys_msg,
+    model=model,
+    tools=tools,
+    summarize_threshold=1,
+    summary_window_ratio=0.02,
+)
+camel_agent.reset()
+
+# Define a user message for creating logs directory
+usr_msg = (
+    f"Create a 'logs' directory in '{workspace_dir}' and list its contents"
+)
+
+# Simulate a long conversation
+for _i in range(20):
+    response = camel_agent.step(usr_msg)
+    print(camel_agent._summary_token_count)
+
+# Handle a large file that may exceed the context window
+usr_msg = "Read this file: ../../uv.lock"
+
+response = camel_agent.step(usr_msg)
diff --git a/test/agents/test_chat_agent.py b/test/agents/test_chat_agent.py
index ffc8a06da5..2e6f7538eb 100644
--- a/test/agents/test_chat_agent.py
+++ b/test/agents/test_chat_agent.py
@@ -560,6 +560,21 @@ def test_chat_agent_step_exceed_token_number(step_call_count=3):
         system_message=system_msg,
         token_limit=1,
     )
+
+    original_get_context = assistant.memory.get_context
+
+    def mock_get_context():
+        messages, _ = original_get_context()
+        # Raise RuntimeError as if context size exceeded limit
+        raise RuntimeError(
+            "Context size exceeded",
+            {
+                "status": "error",
+                "message": "The context has exceeded the maximum token limit.",
+            },
+        )
+
+    assistant.memory.get_context = mock_get_context
     assistant.model_backend.run = MagicMock(
         return_value=model_backend_rsp_base
     )
diff --git a/test/memories/context_creators/test_score_based.py b/test/memories/context_creators/test_score_based.py
index a396212d03..4f862f2967 100644
--- a/test/memories/context_creators/test_score_based.py
+++ b/test/memories/context_creators/test_score_based.py
@@ -70,10 +70,8 @@ def test_score_based_context_creator():
     ]
 
     expected_output = [
-        r.memory_record.to_openai_message()
-        for r in [
-            context_records[1]  # Only expect the highest scoring message
-        ]
+        record.memory_record.to_openai_message()
+        for record in sorted(context_records, key=lambda r: r.timestamp)
     ]
 
     output, _ = context_creator.create_context(records=context_records)
     assert expected_output == output
@@ -137,9 +135,16 @@ def test_score_based_context_creator_with_system_message():
             score=0.9,
         ),
     ]
+
+    sorted_records = sorted(
+        (record for record in context_records[1:]),
+        key=lambda r: r.timestamp,
+    )
+
     expected_output = [
-        r.memory_record.to_openai_message()
-        for r in [context_records[0], context_records[2], context_records[3]]
+        context_records[0].memory_record.to_openai_message(),
+        *(
+            record.memory_record.to_openai_message()
+            for record in sorted_records
+        ),
     ]
 
     output, _ = context_creator.create_context(records=context_records)
     assert expected_output == output
diff --git a/test/memories/test_chat_history_memory.py b/test/memories/test_chat_history_memory.py
index 22daecfe84..3eee644402 100644
--- a/test/memories/test_chat_history_memory.py
+++ b/test/memories/test_chat_history_memory.py
@@ -91,3 +91,69 @@ def test_chat_history_memory(memory: ChatHistoryMemory):
     assert output_messages[0] == system_msg.to_openai_system_message()
     assert output_messages[1] == user_msg.to_openai_user_message()
     assert output_messages[2] == assistant_msg.to_openai_assistant_message()
+
+
+@pytest.mark.parametrize("memory", ["in-memory", "json"], indirect=True)
+def test_chat_history_memory_pop_records(memory: ChatHistoryMemory):
+    system_msg = BaseMessage(
+        "system",
+        role_type=RoleType.DEFAULT,
+        meta_dict=None,
+        content="System instructions",
+    )
+    user_msgs = [
+        BaseMessage(
+            "AI user",
+            role_type=RoleType.USER,
+            meta_dict=None,
+            content=f"Message {idx}",
+        )
+        for idx in range(3)
+    ]
+
+    records = [
+        MemoryRecord(
+            message=system_msg,
+            role_at_backend=OpenAIBackendRole.SYSTEM,
+            timestamp=datetime.now().timestamp(),
+            agent_id="system",
+        ),
+        *[
+            MemoryRecord(
+                message=msg,
+                role_at_backend=OpenAIBackendRole.USER,
+                timestamp=datetime.now().timestamp(),
+                agent_id="user",
+            )
+            for msg in user_msgs
+        ],
+    ]
+
+    memory.write_records(records)
+
+    popped = memory.pop_records(2)
+    assert [record.message.content for record in popped] == [
+        "Message 1",
+        "Message 2",
+    ]
+
+    remaining_messages, _ = memory.get_context()
+    assert [msg['content'] for msg in remaining_messages] == [
+        "System instructions",
+        "Message 0",
+    ]
+
+    # Attempting to pop more than available should leave system message intact.
+    popped = memory.pop_records(5)
+    assert [record.message.content for record in popped] == ["Message 0"]
+
+    remaining_messages, _ = memory.get_context()
+    assert [msg['content'] for msg in remaining_messages] == [
+        "System instructions",
+    ]
+
+    # Zero pop should be a no-op.
+    assert memory.pop_records(0) == []
+
+    with pytest.raises(ValueError):
+        memory.pop_records(-1)
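
Note on the ScoreBasedContextCreator changes earlier in this patch: the removed ranking and truncation logic is interleaved with the new assembly code, which makes the added behaviour hard to read in unified-diff form. The snippet below is a plain-Python restatement of only the added lines, matching what the updated test_score_based.py tests expect (system record first, all remaining records in ascending timestamp order, token total counted once over the final list). The function name is illustrative and the snippet is not part of the patch.

def assemble_context(system_record, remaining_records, token_counter):
    # Keep every non-system record, ordered chronologically.
    ordered = sorted(remaining_records, key=lambda record: record.timestamp)

    messages = []
    if system_record is not None:
        messages.append(system_record.memory_record.to_openai_message())
    messages.extend(
        record.memory_record.to_openai_message() for record in ordered
    )

    if not messages:
        return [], 0
    # The token total is computed over the final message list in one pass.
    return messages, token_counter.count_tokens_from_messages(messages)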
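The oceanbase.py change above swaps next(iter(index_params)) for next(iter(index_params), None) plus a guard. A minimal standalone illustration of why the default argument matters (plain Python, not the OceanBase client API):

index_params = []  # e.g. no vector index parameters were configured

# next() without a default raises StopIteration on an empty iterable;
# supplying None as the default lets the caller skip the work instead.
first_index_param = next(iter(index_params), None)
if first_index_param is None:
    print("No index parameters supplied; skipping vector index creation.")
else:
    print("Would create the vector index with:", first_index_param)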
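The examples/summarization/handle_token_limit.py script above deliberately uses extreme settings (summarize_threshold=1, summary_window_ratio=0.02), presumably so that context compression kicks in almost immediately during a short demo run. A gentler configuration might look like the sketch below; the numbers are illustrative only, not recommended defaults.

from camel.agents import ChatAgent

# Illustrative values: summarize once roughly 80% of the context window is
# in use, and let summaries occupy at most ~30% of the window. The model
# and tools are omitted for brevity; pass them as in the example script.
agent = ChatAgent(
    system_message="You are a helpful assistant.",
    summarize_threshold=80,
    summary_window_ratio=0.3,
)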