diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..d761679 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index 78cdf5a..67da091 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ .mypy_cache/ .pytest_cache/ __pycache__/ +*.DS_Store .ipynb_checkpoints/ .hypothesis/ *.egg-info/ @@ -9,4 +10,5 @@ build/ .venv/ uv.lock .env -src/benchmarking/agent_chat/logs +src/benchmark/simple_benchmarking/agent_chat/logs +src/benchmark/tool_plan_benchmarking/logs diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..3be8f2c --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,65 @@ +# Contributor guide (AGENTS) + +## Project structure + +- `src/` – main package code. + - `src/summarize_algorithms/` – dialogue summarization implementations (e.g. `memory_bank/`, `recsum/`, shared `core/`). + - `src/benchmarking/` – evaluation scripts, metrics, log parsing, and plotting. + - `src/utils/` – small shared helpers (logging/config parsing). + - Entry point: `src/main.py` (also exposed as a script in `pyproject.toml`: `recapkt = "src.main:main"`). +- `tests/` – pytest suite (files follow `test_*.py`). +- `requirements.txt`, `requirements.dev.txt` – runtime/dev dependencies. + +## Build, test, and development commands + +This repo targets **Python >= 3.12** (see `pyproject.toml`). CI uses **uv**. + +- Create env + install deps (recommended): + ```bash + uv venv + uv pip install -r requirements.txt -r requirements.dev.txt + ``` +- Run the example entry point: + ```bash + python -m src.main + # or + uv run recapkt + ``` +- Lint / format (Ruff): + ```bash + ruff check . + ruff format . + ``` +- Type-check (Mypy): + ```bash + uv run mypy + ``` +- Run tests: + ```bash + uv run python -m pytest + ``` +- Tool-metrics benchmarking helper: + ```bash + ./run.sh # runs src/benchmark/tool_plan_benchmarking/run.py + ``` + +## Code style and naming + +- Formatting/linting: Ruff is the source of truth (line length **120**, double quotes). +- Typing: keep functions typed; the project configuration disallows untyped defs in `src/`. +- Naming: + - modules/files: `snake_case.py` + - classes: `CamelCase` + - tests: `tests/test_.py`, test functions `test_()` + +## VCS: commits and pull requests + +- Commit messages follow a lightweight Conventional Commits style seen in history: `feat: ...`, `fix: ...`. +- PRs should: + - describe the change + rationale, + - include how to reproduce/verify (commands or a minimal snippet), + - keep CI green (GitHub Actions runs `ruff check`, `mypy`, `pytest` on PRs). + +## Secrets and local config + +- Don’t commit `.env`. If your change needs new settings, document them and keep defaults safe. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..fc96495 --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +.PHONY: help run-main run-tool-plan run-tool-metrics run-tool-metrics-sh test + +# Prefer local venv if present, fall back to system python. 
+PYTHON := $(shell [ -x .venv/bin/python ] && echo .venv/bin/python || (command -v python3 >/dev/null 2>&1 && echo python3 || echo python)) + +# Optional args for some targets: +# make run-tool-metrics ARG=base_recsum +ARG ?= base_recsum + +help: + @echo "Available targets:" + @echo " make run-main" + @echo " make run-tool-plan" + @echo " make run-tool-metrics ARG=" + @echo " make run-tool-metrics-sh ARG=" + @echo " make test" + +run-main: + $(PYTHON) -m src.main + +run-tool-plan: + cd src/benchmark/tool_plan_benchmarking && $(PYTHON) -m run.py + +run-tool-metrics: + $(MAKE) run-tool-plan + +run-tool-metrics-sh: + ./run.sh $(ARG) + +test: + $(PYTHON) -m pytest -q diff --git a/pyproject.toml b/pyproject.toml index a443010..678c6cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,8 +8,32 @@ authors = [ { name = "Mikhail Kharlamov" } ] readme = "README.md" -requires-python = ">=3.8" -dependencies = [] +requires-python = ">=3.12" +dependencies = [ + "colorlog~=6.10.1", + "dataclasses-json~=0.6.7", + "datasets==4.0.0", + "faiss-cpu==1.11.0", + "jinja2==3.1.6", + "langchain>=1.1.0", + "langchain-community>=0.4.1", + "langchain-core>=1.1.0", + "langchain-ollama>=1.0.1", + "langchain-openai>=1.0.0", + "langgraph>=1.0.0", + "load-dotenv>=0.1.0", + "matplotlib~=3.10.7", + "numpy>=1.26.2", + "openai~=1.109.1", + "pandas~=2.3.3", + "pydantic~=2.11.9", + "pytest>=9.0.2", + "python-dotenv~=1.1.1", + "scikit-learn==1.5.2", + "seaborn~=0.13.2", + "tiktoken==0.9.0", + "transformers>=4.57.6", +] [project.scripts] recapkt = "src.main:main" @@ -48,5 +72,9 @@ warn_no_return = "False" no_implicit_optional = "False" [tool.pytest.ini_options] +pythonpath = ["src"] testpaths = ["tests"] -addopts = ["--color=yes", "-s"] \ No newline at end of file +addopts = ["--color=yes", "-s"] + +[dependency-groups] +dev = [] diff --git a/requirements.txt b/requirements.txt index 5059335..457c6ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,26 @@ -langchain-core>=0.3.72,<1.0.0 -langchain-openai==0.3.28 -langgraph==0.5.3 -langchain>=0.3.27,<0.4.0 +langchain-openai>=1.0.0 +langchain-core>=1.1.0,<2.0.0 +langchain-openai>=0.3.28 +langchain-ollama>=1.0.1 +langgraph>=1.0.0 +langchain>=1.1.0,<2.0.0 +langchain-community>=0.4.1 + tiktoken==0.9.0 +transformers datasets==4.0.0 numpy>=1.26.2 scikit-learn==1.5.2 faiss-cpu==1.11.0 -langchain-community~=0.3.31 pydantic~=2.11.9 -pytest~=8.3.4 dataclasses-json~=0.6.7 -openai~=1.109.1 \ No newline at end of file +openai~=1.109.1 +jinja2==3.1.6 + +python-dotenv~=1.1.1 + +colorlog~=6.10.1 +seaborn~=0.13.2 +pandas~=2.3.3 +matplotlib~=3.10.7 +load_dotenv \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100755 index 0000000..4ea5c35 --- /dev/null +++ b/run.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")" && pwd)" + +source "$ROOT_DIR/.venv/bin/activate" +export PYTHONPATH="$ROOT_DIR" + +cd "$ROOT_DIR/src/benchmark/tool_plan_benchmarking" +python run.py "${1:-}" diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000..759b496 Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/benchmarking/__init__.py b/src/algorithms/__init__.py similarity index 100% rename from src/benchmarking/__init__.py rename to src/algorithms/__init__.py diff --git a/src/algorithms/dialogue.py b/src/algorithms/dialogue.py new file mode 100644 index 0000000..804065a --- /dev/null +++ b/src/algorithms/dialogue.py @@ -0,0 +1,22 @@ +from typing import Any, Protocol + +from 
src.algorithms.summarize_algorithms.core.models import DialogueState, Session + + +class Dialogue(Protocol): + """ + Minimal public interface for a dialogue system used throughout benchmark. + + Any implementation must expose a `system_name` and provide `process_dialogue()` returning a `DialogueState`. + """ + + system_name: str + + def process_dialogue( + self, + sessions: list[Session], + system_prompt: str, + structure: dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None + ) -> DialogueState: + ... diff --git a/src/benchmarking/agent_chat/__init__.py b/src/algorithms/simple_algorithms/__init__.py similarity index 100% rename from src/benchmarking/agent_chat/__init__.py rename to src/algorithms/simple_algorithms/__init__.py diff --git a/src/algorithms/simple_algorithms/dialog_short_tools.py b/src/algorithms/simple_algorithms/dialog_short_tools.py new file mode 100644 index 0000000..e9edc71 --- /dev/null +++ b/src/algorithms/simple_algorithms/dialog_short_tools.py @@ -0,0 +1,31 @@ +from typing import override + +from langchain_core.messages import BaseMessage, ToolMessage + +from src.algorithms.simple_algorithms.dialogue_baseline import DialogueBaseline +from src.algorithms.summarize_algorithms.core.models import Session + + +class DialogueWithShortTools(DialogueBaseline): + """ + Baseline variant that shortens tool messages. + + Keeps tool call structure in the history but clears `ToolMessage.content` to reduce context length. + """ + + @override + @staticmethod + def _compress(sessions: list[Session]) -> list[BaseMessage]: + """ + Compress sessions by clearing tool message contents. + + :param sessions: past sessions. + :return: list[BaseMessage]: flattened history with shortened tool messages. + """ + messages: list[BaseMessage] = DialogueBaseline._get_context(sessions) + + for message in messages: + if isinstance(message, ToolMessage): + message.content = "" + + return messages diff --git a/src/algorithms/simple_algorithms/dialog_with_weights.py b/src/algorithms/simple_algorithms/dialog_with_weights.py new file mode 100644 index 0000000..9ee90e8 --- /dev/null +++ b/src/algorithms/simple_algorithms/dialog_with_weights.py @@ -0,0 +1,49 @@ +from decimal import Decimal +from math import ceil +from typing import override + +from langchain_core.messages import BaseMessage, HumanMessage + +from src.algorithms.simple_algorithms.dialogue_baseline import DialogueBaseline +from src.algorithms.summarize_algorithms.core.models import Session + + +class DialogueWithWeights(DialogueBaseline): + """ + Baseline variant that compresses history by truncating message contents with a positional weight. + + Messages closer to the center of the conversation get truncated more aggressively (triangle-shaped coefficient). + Human messages are preserved. + """ + + @override + @staticmethod + def _compress(sessions: list[Session]) -> list[BaseMessage]: + """ + Compress sessions by truncating non-human messages based on their position. + + :param sessions: past sessions. + :return: list[BaseMessage]: flattened history with weighted truncation applied. 
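A minimal sketch of the effect described above (toy block contents invented for illustration; note the strategy needs at least three messages, since `step = Decimal(1) / Decimal(mid)` with `mid = (len(messages) - 1) // 2`):

```python
from src.algorithms.simple_algorithms.dialog_with_weights import DialogueWithWeights
from src.algorithms.summarize_algorithms.core.models import BaseBlock, Session

# Four toy blocks, so mid = (4 - 1) // 2 = 1 and step = 1.
session = Session([
    BaseBlock(role="user", content="Please refactor the parser."),
    BaseBlock(role="assistant", content="Starting with the tokenizer module..."),
    BaseBlock(role="assistant", content="Here is the full diff of the refactor..."),
    BaseBlock(role="user", content="Looks good, now add tests."),
])

for message in DialogueWithWeights._compress([session]):
    # Human messages keep their content; other messages are cut to a share of
    # their length determined by the positional coefficient.
    print(type(message).__name__, repr(message.content))
```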
+ """ + messages: list[BaseMessage] = DialogueBaseline._get_context(sessions) + cropped_messages: list[BaseMessage] = [] + + mid: int = (len(messages) - 1) // 2 + step: Decimal = Decimal(1) / Decimal(mid) + coefficient: Decimal = Decimal(1) + + for i in range(len(messages)): + if coefficient > 0 and i != 0: + coefficient -= step + else: + coefficient += step + + message = messages[i] + if isinstance(message, HumanMessage): + cropped_messages.append(message) + continue + + message.content = message.content[:ceil(len(message.content) * coefficient)] + cropped_messages.append(message) + + return cropped_messages diff --git a/src/algorithms/simple_algorithms/dialogue_baseline.py b/src/algorithms/simple_algorithms/dialogue_baseline.py new file mode 100644 index 0000000..3031a58 --- /dev/null +++ b/src/algorithms/simple_algorithms/dialogue_baseline.py @@ -0,0 +1,262 @@ +import json +import logging +import os + +from typing import Any + +import tiktoken + +from dotenv import load_dotenv +from langchain_community.callbacks import get_openai_callback +from langchain_core.exceptions import OutputParserException +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import ( + BaseMessage, + SystemMessage, + ToolMessage, + trim_messages, AIMessage, +) +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import Runnable +from langchain_ollama.chat_models import ChatOllama +from langchain_openai import ChatOpenAI +from pydantic import SecretStr +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from src.algorithms.dialogue import Dialogue +from src.algorithms.summarize_algorithms.core.models import ( + DialogueState, + LocalModels, + OpenAIModels, + Session, +) +from src.benchmark.logger.baseline_logger import BaselineLogger +from src.benchmark.tool_plan_benchmarking.tools_and_schemas.parsed_jsons import TOOLS +from src.utils.system_prompt_builder import MemorySections, SystemPromptBuilder + + +class DialogueBaseline(Dialogue): + """ + Baseline dialogue system that answers using the full (compressed) conversation context. + + This implementation does not build or retrieve long-term memory. It simply converts all provided `Session`s into + a single message history, crops it to a token budget, and calls an LLM. + """ + + def __init__(self, system_name: str, llm: BaseChatModel | None = None, is_local: bool = False) -> None: + load_dotenv() + + self.system_name = system_name + + self._initialize_model(llm, is_local) + + self._prompt_builder = SystemPromptBuilder() + self.prompt_tokens = 0 + self.completion_tokens = 0 + self.total_cost = 0.0 + + self.baseline_logger = BaselineLogger() + + def _initialize_model(self, llm: BaseChatModel | None = None, is_local: bool = False) -> None: + """ + Initialize the underlying chat model. + + :param llm: optional externally constructed model instance. + :param is_local: if True, uses an Ollama model; otherwise uses OpenAI. 
+ :return: None + """ + if is_local: + self.llm = ChatOllama( + model=LocalModels.QWEN_2_5_14_B.value, + temperature=0.7, + keep_alive="1h" + ) + return + + load_dotenv() + + api_key: str | None = os.getenv("OPENAI_API_KEY") + if api_key is not None: + self.llm = llm or ChatOpenAI( + model=OpenAIModels.GPT_4_O_MINI.value, + api_key=SecretStr(api_key) + ) + else: + raise ValueError("OPENAI_API_KEY environment variable is not loaded") + + def _build_chain( + self, + structure: dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + ) -> Runnable: + """Build a runnable used for response generation. + + Mirrors `ResponseGenerator._build_chain()` behavior: + - `process_dialogue()` assembles the full prompt as `list[BaseMessage]` + (unified SystemMessage + conversation history) + - therefore the chain must accept `list[BaseMessage]` directly (no prompt variables) + """ + if structure and not tools: + return self.llm.with_structured_output(structure) + + if tools and not structure: + return self.llm.bind_tools(tools) + + if tools and structure: + tools = [DialogueBaseline._get_return_action_plan(structure), *tools] + return self.llm.bind_tools(tools) + + return self.llm | StrOutputParser() + + @staticmethod + def _get_return_action_plan(structure: dict[str, Any]) -> dict[str, Any]: + return { + "type": "function", + "function": { + "name": "return_action_plan", + "description": "Return the final JSON action plan strictly matching the schema.", + "parameters": structure, + }, + } + + @staticmethod + def _get_context(sessions: list[Session]) -> list[BaseMessage]: + context_messages: list[BaseMessage] = [] + for session in sessions: + context_messages.extend(session.to_langchain_messages()) + return context_messages + + @staticmethod + def __count_tokens(text: str) -> int: + encoding = tiktoken.get_encoding("o200k_base") + tokens = encoding.encode(text) + return len(tokens) + + def _crop(self, messages: list[Any], max_tokens: int = 80_000) -> list[BaseMessage]: + """ + Trim a list of messages to fit into a token budget. + + This preserves an initial `SystemMessage` (if present) and ensures the first message after trimming is not a + `ToolMessage` (some models/tooling can break when history starts with tool output). + + :param messages: message list to trim. + :param max_tokens: token budget. + :return: list[BaseMessage]: trimmed messages. + """ + system_msg = None + if messages and isinstance(messages[0], SystemMessage): + system_msg = messages[0] + messages_to_trim = messages[1:] + else: + messages_to_trim = messages + + # Log token counts before trimming + total_tokens_before_crop = self.llm.get_num_tokens_from_messages(messages_to_trim) + logging.info(f"Total tokens before crop: {total_tokens_before_crop}") + + trimmed_messages = trim_messages( + messages_to_trim, + token_counter=self.llm, + max_tokens=max_tokens, + strategy="last", + include_system=True, + allow_partial=False, + ) + + while trimmed_messages and isinstance(trimmed_messages[0], ToolMessage): + trimmed_messages.pop(0) + + if system_msg: + # Log token count after trimming + total_tokens_after_crop = self.llm.get_num_tokens_from_messages(trimmed_messages) + logging.info(f"Total tokens after crop (without system message): {total_tokens_after_crop}") + return [system_msg] + trimmed_messages + return trimmed_messages + + @staticmethod + def _compress(sessions: list[Session]) -> list[BaseMessage]: + """ + Convert sessions into a single message history. 
+ + Subclasses override this to apply different context compression strategies. + + :param sessions: past sessions. + :return: list[BaseMessage]: flattened message history. + """ + return DialogueBaseline._get_context(sessions) + + def process_dialogue( + self, + sessions: list[Session], + system_prompt: str, + structure: dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + ) -> DialogueState: + """ + Generate a response using the baseline approach (no explicit memory). + + :param sessions: past user/assistant/tool interactions. + :param system_prompt: system prompt (used as the system message). + :param structure: optional schema for structured output. + :param tools: optional tools/functions specs for tool calling. + :return: DialogueState: contains the prepared history and model response. + """ + compressed_sessions: list[BaseMessage] = type(self)._compress(sessions) + # Log tokens before crop + total_tokens_before_crop = self.llm.get_num_tokens_from_messages(compressed_sessions) + logging.info(f"Total tokens before crop (compressed sessions): {total_tokens_before_crop}") + + context: list[BaseMessage] = self._crop(compressed_sessions) + + system_instruction = self._prompt_builder.build( + schema=structure, + tools=TOOLS, + memory=MemorySections(), + memory_mode="baseline", + ) + + # Log tokens in system message + system_message_tokens = self.llm.get_num_tokens_from_messages([SystemMessage(content=system_instruction)]) + logging.info(f"Tokens in system message: {system_message_tokens}") + + # Log tokens in context after crop + total_tokens_after_crop = self.llm.get_num_tokens_from_messages(context) + logging.info(f"Total tokens in context after crop: {total_tokens_after_crop}") + + chain = self._build_chain(structure, tools) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + retry=retry_if_exception_type(OutputParserException), + reraise=True + ) + def invoke_with_retry(full_history: list[BaseMessage]) -> Any: + logging.info("Attempting to invoke chain...") + return chain.invoke(full_history) + + # Ensure the unified SystemMessage is the first message in the flow. 
+ context_with_system: list[BaseMessage] = self._crop( + [SystemMessage(content=system_instruction), *context] + ) + + with get_openai_callback() as cb: + result = invoke_with_retry(context_with_system) + + self.prompt_tokens += cb.prompt_tokens + self.completion_tokens += cb.completion_tokens + self.total_cost += cb.total_cost + + return DialogueState( + dialogue_sessions=sessions, + prepared_messages=context_with_system, + query=system_prompt, + _response=result, + code_memory_storage=None, + tool_memory_storage=None + ) diff --git a/src/benchmarking/tool_metrics/__init__.py b/src/algorithms/summarize_algorithms/__init__.py similarity index 100% rename from src/benchmarking/tool_metrics/__init__.py rename to src/algorithms/summarize_algorithms/__init__.py diff --git a/src/summarize_algorithms/__init__.py b/src/algorithms/summarize_algorithms/core/__init__.py similarity index 100% rename from src/summarize_algorithms/__init__.py rename to src/algorithms/summarize_algorithms/core/__init__.py diff --git a/src/algorithms/summarize_algorithms/core/base_dialogue_system.py b/src/algorithms/summarize_algorithms/core/base_dialogue_system.py new file mode 100644 index 0000000..c4bf07c --- /dev/null +++ b/src/algorithms/summarize_algorithms/core/base_dialogue_system.py @@ -0,0 +1,209 @@ +import functools +import logging +import os + +from abc import ABC, abstractmethod +from typing import Any + +from dotenv import load_dotenv +from langchain_community.callbacks import get_openai_callback +from langchain_core.embeddings import Embeddings +from langchain_core.exceptions import OutputParserException +from langchain_core.language_models import BaseChatModel +from langchain_core.prompts import PromptTemplate +from langchain_ollama.chat_models import ChatOllama +from langchain_openai import ChatOpenAI +from langgraph.constants import END +from langgraph.graph import StateGraph +from langgraph.graph.state import CompiledStateGraph +from pydantic import SecretStr +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from src.algorithms.dialogue import Dialogue +from src.algorithms.summarize_algorithms.core.graph_nodes import ( + UpdateState, + generate_response_node, + should_continue_memory_update, + update_memory_node, +) +from src.algorithms.summarize_algorithms.core.models import ( + DialogueState, + LocalModels, + OpenAIModels, + Session, + WorkflowNode, +) +from src.algorithms.summarize_algorithms.core.prompts import RESPONSE_GENERATION_PROMPT +from src.algorithms.summarize_algorithms.core.response_generator import ( + ResponseGenerator, +) +from src.benchmark.logger.memory_logger import MemoryLogger + + +class BaseDialogueSystem(ABC, Dialogue): + _MAX_PROMPT_TOKENS = 80_000 + + """ + Shared LangGraph-based implementation for dialogue systems in this repository. + + The pipeline is built as a graph with two main stages: + 1) update memory (via a concrete `BaseSummarizer` implementation) + 2) generate the final response (via `ResponseGenerator`, optionally with tools/structured output) + + Subclasses plug in the summarizer and the initial `DialogueState`. 
+ """ + + def __init__( + self, + llm: BaseChatModel | None = None, + embed_code: bool = False, + embed_tool: bool = False, + embed_model: Embeddings | None = None, + max_session_id: int = 3, + system_name: str | None = None, + is_local: bool = False, + ) -> None: + self.system_name = system_name or self.__class__.__name__ + + self._initialize_model(llm, is_local) + + self.summarizer = self._build_summarizer() + self.graph = self._build_graph() + self.state: DialogueState | None = None + self.embed_code = embed_code + self.embed_tool = embed_tool + self.embed_model = embed_model + self.max_session_id = max_session_id + self.prompt_tokens = 0 + self.completion_tokens = 0 + self.total_cost = 0.0 + + self.memory_logger = MemoryLogger() + self.iteration = 0 + + @abstractmethod + def _build_summarizer(self) -> Any: + pass + + @staticmethod + def _get_response_prompt_template() -> PromptTemplate: + return RESPONSE_GENERATION_PROMPT + + @abstractmethod + def _get_initial_state(self, sessions: list[Session], last_session: Session, query: str) -> DialogueState: + pass + + @property + @abstractmethod + def _get_dialogue_state_class(self) -> type[DialogueState]: + pass + + def _initialize_model(self, llm: BaseChatModel | None = None, is_local: bool = False) -> None: + load_dotenv() + + api_key: str | None = os.getenv("OPENAI_API_KEY") + if api_key is not None: + self.memory_llm = ChatOpenAI( + model=OpenAIModels.GPT_5_MINI.value, + api_key=SecretStr(api_key) + ) + else: + raise ValueError("OPENAI_API_KEY environment variable is not loaded") + + if is_local: + self.llm = ChatOllama( + model=LocalModels.QWEN_2_5_14_B.value, + temperature=0, + keep_alive="1h" + ) + else: + self.llm = llm or ChatOpenAI( + model=OpenAIModels.GPT_4_O_MINI.value, + api_key=SecretStr(api_key) + ) + + + def _build_graph( + self, + structure: dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + ) -> CompiledStateGraph: + self.response_generator = ResponseGenerator( + self.llm, + structure=structure, + tools=tools, + max_prompt_tokens=self._MAX_PROMPT_TOKENS, + ) + + workflow = StateGraph(self._get_dialogue_state_class) + + workflow.add_node( + WorkflowNode.UPDATE_MEMORY.value, + functools.partial(update_memory_node, self.summarizer), + ) + workflow.add_node( + WorkflowNode.GENERATE_RESPONSE.value, + functools.partial(generate_response_node, self.response_generator), + ) + + workflow.set_entry_point(WorkflowNode.UPDATE_MEMORY.value) + + workflow.add_conditional_edges( + WorkflowNode.UPDATE_MEMORY.value, + should_continue_memory_update, + { + UpdateState.CONTINUE_UPDATE.value: WorkflowNode.UPDATE_MEMORY.value, + UpdateState.FINISH_UPDATE.value: WorkflowNode.GENERATE_RESPONSE.value, + }, + ) + + workflow.add_edge(WorkflowNode.GENERATE_RESPONSE.value, END) + + return workflow.compile() + + def process_dialogue( + self, + sessions: list[Session], + system_prompt: str, + structure: dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None + ) -> DialogueState: + """ + Run the dialogue workflow and return the final `DialogueState`. + + :param sessions: past user/assistant/tool interactions (last element is treated as the current session). + :param system_prompt: system prompt template used during response generation. + :param structure: optional JSON schema for structured model output. + :param tools: optional tools/functions specs for tool calling. + :return: DialogueState: state populated with updated memory and the generated response. 
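A hedged usage sketch of the workflow described above (the toy sessions are invented and `OPENAI_API_KEY` is assumed to be configured; `RecsumDialogueSystem` is one of the concrete subclasses added in this diff):

```python
from src.algorithms.summarize_algorithms.core.models import BaseBlock, Session
from src.algorithms.summarize_algorithms.recsum.dialogue_system import RecsumDialogueSystem

sessions = [
    Session([BaseBlock(role="user", content="Set up CI for the repo."),
             BaseBlock(role="assistant", content="Added a GitHub Actions workflow.")]),
    Session([BaseBlock(role="user", content="Now make the tests run on pull requests.")]),
]

# Requires OPENAI_API_KEY in the environment (memory LLM + embeddings).
system = RecsumDialogueSystem(system_name="recsum")
state = system.process_dialogue(sessions, system_prompt="You are a coding assistant.")
print(state.response)
```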
+ """ + graph = self._build_graph(structure, tools) + initial_state = self._get_initial_state(sessions, sessions[-1], system_prompt) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + retry=retry_if_exception_type(OutputParserException), + reraise=True, + ) + def invoke_with_retry(state: DialogueState) -> dict[str, Any]: + logging.info("Attempting to invoke graph...") + result = graph.invoke(state) + if not isinstance(result, dict): + raise TypeError(f"Graph invocation returned unexpected type: {type(result)}") + return result + + with get_openai_callback() as cb: + result_state = invoke_with_retry(initial_state) + self.state = self._get_dialogue_state_class(**result_state) + + self.prompt_tokens += cb.prompt_tokens + self.completion_tokens += cb.completion_tokens + self.total_cost += cb.total_cost + + return self.state if self.state is not None else initial_state diff --git a/src/algorithms/summarize_algorithms/core/base_summarizer.py b/src/algorithms/summarize_algorithms/core/base_summarizer.py new file mode 100644 index 0000000..cafc300 --- /dev/null +++ b/src/algorithms/summarize_algorithms/core/base_summarizer.py @@ -0,0 +1,38 @@ +from abc import ABC, abstractmethod +from typing import Any + +from langchain_core.language_models import BaseChatModel +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import Runnable + + +class BaseSummarizer(ABC): + """ + Base class for summarizers used to update long-term dialogue memory. + + A summarizer wraps an LLM + prompt into a reusable LangChain runnable (`self.chain`). Concrete implementations + define how to build the chain and which inputs they accept in `summarize()`. + """ + + def __init__(self, llm: BaseChatModel, prompt: PromptTemplate) -> None: + self.llm = llm + self.prompt = prompt + self.chain = self._build_chain() + + @abstractmethod + def _build_chain(self) -> Runnable[dict[str, Any], Any]: + pass + + @abstractmethod + def summarize(self, *args: Any, **kwargs: Any) -> Any: + """ + Summarize/update memory. + + Concrete implementations define the accepted inputs (e.g. previous memory + dialogue context) and the + returned memory representation. + + :param args: positional arguments required by concrete summarizers. + :param kwargs: keyword arguments required by concrete summarizers. + :return: Any: summarizer-specific memory output. 
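For illustration only, a minimal hypothetical subclass showing the contract (the repository's real implementations are `SessionSummarizer` and `RecursiveSummarizer` further down in this diff):

```python
from typing import Any

from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable

from src.algorithms.summarize_algorithms.core.base_summarizer import BaseSummarizer


class OneShotSummarizer(BaseSummarizer):
    """Toy summarizer: a single prompt -> plain-text memory string."""

    def _build_chain(self) -> Runnable[dict[str, Any], Any]:
        return self.prompt | self.llm | StrOutputParser()

    def summarize(self, dialogue_context: str) -> str:
        # Assumes the prompt exposes a `dialogue_context` input variable.
        return self.chain.invoke({"dialogue_context": dialogue_context})
```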
+ """ + pass diff --git a/src/summarize_algorithms/core/graph_nodes.py b/src/algorithms/summarize_algorithms/core/graph_nodes.py similarity index 55% rename from src/summarize_algorithms/core/graph_nodes.py rename to src/algorithms/summarize_algorithms/core/graph_nodes.py index 58ab26a..6041e68 100644 --- a/src/summarize_algorithms/core/graph_nodes.py +++ b/src/algorithms/summarize_algorithms/core/graph_nodes.py @@ -1,19 +1,23 @@ -from src.summarize_algorithms.core.base_summarizer import BaseSummarizer -from src.summarize_algorithms.core.models import ( +from src.algorithms.summarize_algorithms.core.base_summarizer import BaseSummarizer +from src.algorithms.summarize_algorithms.core.models import ( + BaseBlock, DialogueState, + MemoryBankDialogueState, + MemoryDialogueState, RecsumDialogueState, + ResponseContext, + Session, UpdateState, ) -from src.summarize_algorithms.core.response_generator import ResponseGenerator +from src.algorithms.summarize_algorithms.core.response_generator import ( + ResponseGenerator, +) +from src.utils.system_prompt_builder import MemorySections def update_memory_node( - summarizer_instance: BaseSummarizer, state: DialogueState -) -> DialogueState: - from src.summarize_algorithms.memory_bank.dialogue_system import ( - MemoryBankDialogueState, - ) - + summarizer_instance: BaseSummarizer, state: MemoryDialogueState +) -> MemoryDialogueState: current_dialogue_session = state.dialogue_sessions[state.current_session_index] if state.code_memory_storage is not None: @@ -51,39 +55,42 @@ def update_memory_node( def generate_response_node( - response_generator_instance: ResponseGenerator, state: DialogueState -) -> DialogueState: - from src.summarize_algorithms.memory_bank.dialogue_system import ( - MemoryBankDialogueState, - ) - + response_generator_instance: ResponseGenerator, + state: MemoryDialogueState +) -> MemoryDialogueState: if isinstance(state, RecsumDialogueState): - dialogue_memory = state.latest_memory + text_memory = state.latest_memory elif isinstance(state, MemoryBankDialogueState): - dialogue_memory = "\n".join(state.text_memory_storage.find_similar(state.query)) + text_memory_blocks = state.text_memory_storage.find_similar(state.query) + text_memory = str(Session(text_memory_blocks)) else: raise TypeError( f"Unsupported status type for update_memory_node: {type(state)}" ) + code_memory: list[BaseBlock] = [] if state.code_memory_storage is not None: - code_memory = "\n".join(state.code_memory_storage.find_similar(state.query)) - else: - code_memory = "Code Memory is missing" + code_memory = state.code_memory_storage.find_similar(state.query) + tool_memory: list[BaseBlock] = [] if state.tool_memory_storage is not None: - tool_memory = "\n".join(state.tool_memory_storage.find_similar(state.query)) - else: - tool_memory = "Tool Memory is missing" + tool_memory = state.tool_memory_storage.find_similar(state.query) + + memory_sections = MemorySections( + recap=text_memory if isinstance(state, RecsumDialogueState) else None, + memory_bank=text_memory if isinstance(state, MemoryBankDialogueState) else None, + code_knowledge=str(Session(code_memory)) if len(code_memory) > 0 else None, + tool_memory=str(Session(tool_memory)) if len(tool_memory) > 0 else None, + ) - final_response = response_generator_instance.generate_response( - dialogue_memory=dialogue_memory, - code_memory=code_memory, - tool_memory=tool_memory, - query=state.query, + final_response: ResponseContext = response_generator_instance.generate_response( + last_session=state.last_session, + 
user_query=state.query, + memory=memory_sections ) - state._response = final_response + state._response = final_response.response + state.prepared_messages = final_response.prepared_history return state diff --git a/src/summarize_algorithms/core/__init__.py b/src/algorithms/summarize_algorithms/core/memory_storage/__init__.py similarity index 100% rename from src/summarize_algorithms/core/__init__.py rename to src/algorithms/summarize_algorithms/core/memory_storage/__init__.py diff --git a/src/summarize_algorithms/core/memory_storage.py b/src/algorithms/summarize_algorithms/core/memory_storage/memory_storage.py similarity index 62% rename from src/summarize_algorithms/core/memory_storage.py rename to src/algorithms/summarize_algorithms/core/memory_storage/memory_storage.py index f572fc3..2c3ae81 100644 --- a/src/summarize_algorithms/core/memory_storage.py +++ b/src/algorithms/summarize_algorithms/core/memory_storage/memory_storage.py @@ -1,8 +1,8 @@ import math import os -from dataclasses import dataclass -from typing import Any, Iterable, Optional +from collections.abc import Iterable +from typing import Any import faiss import numpy as np @@ -12,20 +12,29 @@ from langchain_openai import OpenAIEmbeddings from pydantic import SecretStr -from src.summarize_algorithms.core.models import BaseBlock, CodeBlock +from src.algorithms.summarize_algorithms.core.memory_storage.models import ( + CodeMemoryFragment, + MemoryFragment, + ToolMemoryFragment, +) +from src.algorithms.summarize_algorithms.core.models import ( + BaseBlock, + CodeBlock, + ToolCallBlock, +) -@dataclass -class MemoryFragment: - embed_content: str - content: str - session_id: int +class MemoryStorage: + """ + Vector store for dialogue memory fragments (text/code/tool). + Stores embeddings in a FAISS inner-product index and keeps the original content in a parallel list. + Used by dialogue systems to retrieve top-k relevant past fragments for a given query. + """ -class MemoryStorage: def __init__( self, - embeddings: Optional[Embeddings] = None, + embeddings: Embeddings | None = None, max_session_id: int = 3, ) -> None: load_dotenv() @@ -58,6 +67,13 @@ def _normalize_vectors(vectors: np.ndarray) -> np.ndarray: return vectors / norms def add_memory(self, memories: Iterable[BaseBlock], session_id: int) -> None: + """ + Embed and index new memory blocks for a given session. + + :param memories: blocks to store (text/code/tool). Empty iterables are ignored. + :param session_id: index of the session these blocks came from (used for time-decay weighting). 
+ :return: None + """ if not memories: return @@ -80,18 +96,27 @@ def add_memory(self, memories: Iterable[BaseBlock], session_id: int) -> None: self.index.add(weighted_embeddings) for memory in memories: - if isinstance(memory, CodeBlock): - content = memory.code + if isinstance(memory, ToolCallBlock): + self.memory_list.append( + ToolMemoryFragment.from_block(memory, session_id=session_id) + ) + elif isinstance(memory, CodeBlock): + self.memory_list.append( + CodeMemoryFragment.from_block(memory, session_id=session_id) + ) else: - content = memory.content - - self.memory_list.append( - MemoryFragment( - embed_content=memory.content, content=content, session_id=session_id + self.memory_list.append( + MemoryFragment.from_block(memory, session_id=session_id) ) - ) - def find_similar(self, query: str, top_k: int = 5) -> list[str]: + def find_similar(self, query: str, top_k: int = 5) -> list[BaseBlock]: + """ + Return top-k stored blocks most similar to the query (cosine/IP on normalized vectors). + + :param query: query string to search for. + :param top_k: maximum number of results to return. + :return: list[BaseBlock]: retrieved blocks ordered by similarity. + """ if self.index is None or len(self.memory_list) == 0: return [] @@ -104,9 +129,9 @@ def find_similar(self, query: str, top_k: int = 5) -> list[str]: normalized_query, min(top_k, len(self.memory_list)) )[1] - results = [] + results: list[BaseBlock] = [] for idx in indices[0]: - results.append(self.memory_list[idx].content) + results.append(self.memory_list[idx].to_block()) return results @@ -114,6 +139,12 @@ def get_memory_count(self) -> int: return len(self.memory_list) def get_session_memory(self, session_id: int) -> list[str]: + """ + Get raw stored contents for a specific session id. + + :param session_id: session index. + :return: list[str]: stored fragment contents. + """ if session_id < 0 or session_id >= self.max_session_id: raise ValueError( f"Session ID must be between 0 and {self.max_session_id - 1}." @@ -126,6 +157,11 @@ def get_session_memory(self, session_id: int) -> list[str]: ] def to_dict(self) -> dict[str, Any]: + """ + Serialize the storage metadata (memory list + FAISS index info) for logging/debugging. + + :return: dict[str, Any]: JSON-serializable snapshot of the storage. + """ return { "memory_list": [ { diff --git a/src/algorithms/summarize_algorithms/core/memory_storage/models.py b/src/algorithms/summarize_algorithms/core/memory_storage/models.py new file mode 100644 index 0000000..5f5d4cd --- /dev/null +++ b/src/algorithms/summarize_algorithms/core/memory_storage/models.py @@ -0,0 +1,103 @@ +from dataclasses import dataclass +from typing import override + +from src.algorithms.summarize_algorithms.core.models import ( + BaseBlock, + CodeBlock, + ToolCallBlock, +) + + +@dataclass +class MemoryFragment: + """Serializable representation of a remembered block used by `MemoryStorage`.""" + embed_content: str + content: str + role: str + session_id: int + + @classmethod + def from_block(cls, block: BaseBlock, session_id: int) -> "MemoryFragment": + """ + Create a fragment from a session block. + + :param block: source message block. + :param session_id: session index. + :return: MemoryFragment: created fragment. + """ + return cls( + embed_content=block.content, + content=block.content, + session_id=session_id, + role=block.role, + ) + + def to_block(self) -> BaseBlock: + """ + Convert the fragment back to a `BaseBlock` (used when retrieving from the vector store). + + :return: BaseBlock: restored block. 
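A small retrieval sketch tying store and lookup together (the stored facts are invented; assumes `OPENAI_API_KEY` is set so the default `OpenAIEmbeddings` can be constructed):

```python
from src.algorithms.summarize_algorithms.core.memory_storage.memory_storage import MemoryStorage
from src.algorithms.summarize_algorithms.core.models import BaseBlock

storage = MemoryStorage(max_session_id=3)
storage.add_memory([BaseBlock(role="user", content="We deploy with Docker Compose.")], session_id=0)
storage.add_memory([BaseBlock(role="user", content="CI runs ruff, mypy and pytest.")], session_id=1)

# Fragments come back as blocks via `to_block()`, ordered by similarity.
for block in storage.find_similar("how is the project deployed?", top_k=2):
    print(block.role, block.content)
```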
+ """ + return BaseBlock( + role="assistant", + content=self.content, + ) + + +@dataclass +class ToolMemoryFragment(MemoryFragment): + """Specialized fragment for tool calls/responses (keeps tool metadata).""" + id: str + name: str + arguments: str + response: str + + @override + @classmethod + def from_block(cls, block: ToolCallBlock, session_id: int): + return cls( + embed_content=block.content, + content=block.content, + session_id=session_id, + id=block.id, + name=block.name, + arguments=block.arguments, + response=block.response, + role=block.role, + ) + + @override + def to_block(self): + return ToolCallBlock( + role=self.role, + content=self.content, + id=self.id, + name=self.name, + arguments=self.arguments, + response=self.response, + ) + + +@dataclass +class CodeMemoryFragment(MemoryFragment): + """Specialized fragment for code blocks (embeds `code`, not `content`).""" + code: str + + @override + @classmethod + def from_block(cls, block: CodeBlock, session_id: int): + return cls( + embed_content=block.code, + content=block.code, + session_id=session_id, + code=block.code, + role=block.role, + ) + + @override + def to_block(self): + return CodeBlock( + role=self.role, + content="", + code=self.code, + ) diff --git a/src/algorithms/summarize_algorithms/core/models.py b/src/algorithms/summarize_algorithms/core/models.py new file mode 100644 index 0000000..f74c1a9 --- /dev/null +++ b/src/algorithms/summarize_algorithms/core/models.py @@ -0,0 +1,259 @@ +import json +import logging +from collections.abc import Iterator +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from dataclasses_json import dataclass_json, DataClassJsonMixin +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) + + +class OpenAIModels(Enum): + """Names of OpenAI chat models used in this repository.""" + GPT_3_5_TURBO = "gpt-3.5-turbo" + GPT_4_1_MINI = "gpt-4.1-mini" + GPT_4_O = "gpt-4o" + GPT_4_1 = "gpt-4.1" + GPT_5_NANO = "gpt-5-nano" + GPT_5_MINI = "gpt-5-mini" + GPT_4_O_MINI = "gpt-4o-mini" + + +class LocalModels(Enum): + """Names of local (Ollama) models used in this repository.""" + GEMMA_2_9_B = "gemma2:9b" + QWEN_2_5_14_B = "qwen2.5:14b" + + +@dataclass +class BaseBlock(DataClassJsonMixin): + """A single message block in a `Session` (role + textual content).""" + role: str + content: str + + def __str__(self) -> str: + return f"{self.role}: {self.content}" + + +@dataclass +class CodeBlock(BaseBlock): + """A message block that contains code (stored in `code`).""" + code: str + + +@dataclass +class ToolCallBlock(BaseBlock): + """Represents a tool/function call and its response within a session.""" + id: str + name: str + arguments: str + response: str + + +class Session: + """ + Ordered list of dialogue blocks (user/assistant/code/tool) with helpers for LangChain conversion. + + This is the primary interchange format between dataset loaders, dialogue systems, and benchmark. 
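A short sketch of the conversion helpers (block contents are invented):

```python
from src.algorithms.summarize_algorithms.core.models import (
    BaseBlock,
    CodeBlock,
    Session,
    ToolCallBlock,
)

session = Session([
    BaseBlock(role="user", content="Show me the config loader."),
    ToolCallBlock(role="assistant", content="read src/utils/config.py", id="call_1",
                  name="read_file", arguments='{"path": "src/utils/config.py"}', response="..."),
    CodeBlock(role="assistant", content="", code="def load_config(path): ..."),
])

# A ToolCallBlock expands into an AIMessage with `tool_calls` plus a ToolMessage;
# a CodeBlock becomes an AIMessage carrying the code.
for message in session.to_langchain_messages():
    print(type(message).__name__)
```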
+ """ + + def __init__(self, messages: list[BaseBlock]) -> None: + self.messages = messages + + def __len__(self) -> int: + return len(self.messages) + + def __str__(self) -> str: + if len(self.messages) == 0: + return "missing" + + result_messages = [] + for msg in self.messages: + if isinstance(msg, CodeBlock): + result_messages.append(f"{msg.role}: {msg.code}") + if isinstance(msg, ToolCallBlock): + result_messages.append( + f"Tool Call [{msg.id}]: {msg.name} - {msg.arguments} -> {msg.response}: {msg.content}" + ) + else: + result_messages.append(f"{msg.role}: {msg.content}") + return "\n".join(result_messages) + + def __getitem__(self, index: int) -> BaseBlock: + return self.messages[index] + + def __iter__(self) -> Iterator[BaseBlock]: + return iter(self.messages) + + def to_dict(self) -> dict[str, Any]: + result_messages = [] + for msg in self.messages: + if isinstance(msg, CodeBlock): + result_messages.append({ + "type": "code", + "role": msg.role, + "code": msg.code, + }) + elif isinstance(msg, ToolCallBlock): + result_messages.append({ + "type": "tool_call", + "id": msg.id, + "name": msg.name, + "arguments": msg.arguments, + "response": msg.response, + }) + else: + result_messages.append({ + "type": "text", + "role": msg.role, + "content": msg.content, + }) + return {"messages": result_messages} + + def to_langchain_messages(self) -> list[BaseMessage]: + langchain_messages: list[BaseMessage] = [] + for msg in self.messages: + if isinstance(msg, CodeBlock): + langchain_messages.append(AIMessage(content=msg.code)) + + elif isinstance(msg, ToolCallBlock): + try: + ai_tool_call = { + "name": msg.name, + "args": json.loads(msg.arguments) + if isinstance(msg.arguments, str) and msg.arguments != "" + else {}, + "id": msg.id + } + except json.decoder.JSONDecodeError as e: + logging.error(e) + ai_tool_call = { + "name": msg.name, + "args": {}, + "id": msg.id + } + + langchain_messages.append(AIMessage( + content="", + tool_calls=[ai_tool_call] + )) + + langchain_messages.append(ToolMessage( + response=msg.response, + content=msg.content, + tool_call_id=msg.id, + name=msg.name + )) + + else: + if msg.role.lower() in ["user", "human"]: + langchain_messages.append(HumanMessage(content=msg.content)) + elif msg.role.lower() in ["system"]: + langchain_messages.append(SystemMessage(content=msg.content)) + else: + langchain_messages.append(AIMessage(content=msg.content)) + + return langchain_messages + + def get_messages_by_role(self, role: str) -> list[BaseBlock]: + return [msg for msg in self.messages if msg.role == role] + + def get_text_blocks(self) -> list[BaseBlock]: + return [ + msg + for msg in self.messages + if not (isinstance(msg, CodeBlock) or isinstance(msg, ToolCallBlock)) + ] + + def get_code_blocks(self) -> list[CodeBlock]: + return [msg for msg in self.messages if isinstance(msg, CodeBlock)] + + def get_tool_calls(self) -> list[ToolCallBlock]: + return [msg for msg in self.messages if isinstance(msg, ToolCallBlock)] + + +@dataclass_json +@dataclass +class DialogueState: + """ + Mutable state passed through the LangGraph workflow. + + Contains the dialogue history (`dialogue_sessions`), memory stores, and the final generated response. 
+ """ + from src.algorithms.summarize_algorithms.core.memory_storage.memory_storage import ( + MemoryStorage, + ) + + dialogue_sessions: list[Session] + prepared_messages: list[BaseMessage] + code_memory_storage: MemoryStorage | None + tool_memory_storage: MemoryStorage | None + query: str + current_session_index: int = 0 + _response: str | dict[str, Any] | None = None + + @property + def response(self) -> str | dict[str, Any]: + if self._response is None: + raise ValueError("Response has not been generated yet.") + return self._response + + @property + def current_context(self) -> Session: + return self.dialogue_sessions[-1] + + +@dataclass_json +@dataclass +class MemoryDialogueState(DialogueState): + """`DialogueState` that also tracks the `last_session` explicitly (used by memory-based systems).""" + last_session: Session = field(default_factory=lambda: Session([])) + + +@dataclass_json +@dataclass +class RecsumDialogueState(MemoryDialogueState): + """Dialogue state for `RecsumDialogueSystem` (keeps iterative `text_memory`).""" + text_memory: list[list[str]] = field(default_factory=list) + + @property + def latest_memory(self) -> str: + return "\n".join(self.text_memory[-1]) if self.text_memory else "" + + +@dataclass_json +@dataclass +class MemoryBankDialogueState(MemoryDialogueState): + """Dialogue state for `MemoryBankDialogueSystem` (stores session summaries in `text_memory_storage`).""" + from src.algorithms.summarize_algorithms.core.memory_storage.memory_storage import ( + MemoryStorage, + ) + + text_memory_storage: MemoryStorage = field(default_factory=MemoryStorage) + + +class WorkflowNode(Enum): + """Named nodes in the LangGraph dialogue workflow.""" + UPDATE_MEMORY = "update_memory" + GENERATE_RESPONSE = "generate_response" + + +class UpdateState(Enum): + """Routing values used by `should_continue_memory_update` to control the graph loop.""" + CONTINUE_UPDATE = "continue_update" + FINISH_UPDATE = "finish_update" + + +@dataclass_json +@dataclass +class ResponseContext: + """Output of response generation (raw response + the prepared message history sent to the model).""" + response: Any + prepared_history: list[BaseMessage] diff --git a/src/summarize_algorithms/core/prompts.py b/src/algorithms/summarize_algorithms/core/prompts.py similarity index 94% rename from src/summarize_algorithms/core/prompts.py rename to src/algorithms/summarize_algorithms/core/prompts.py index faf9c2b..2732301 100644 --- a/src/summarize_algorithms/core/prompts.py +++ b/src/algorithms/summarize_algorithms/core/prompts.py @@ -2,7 +2,7 @@ RESPONSE_GENERATION_PROMPT = PromptTemplate.from_template( """ -You are an advanced AI agent specializing in working with code and technical tasks, +You are an advanced AI simple_algorithms specializing in working with code and technical tasks, but also capable of engaging in friendly, natural conversation. 
Your goal is to generate useful, accurate, and personalized responses using three types of input memory: diff --git a/src/algorithms/summarize_algorithms/core/response_generator.py b/src/algorithms/summarize_algorithms/core/response_generator.py new file mode 100644 index 0000000..dcc446a --- /dev/null +++ b/src/algorithms/summarize_algorithms/core/response_generator.py @@ -0,0 +1,236 @@ +import logging + +from typing import Any + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import ( + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, + trim_messages, +) +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import Runnable + +from src.algorithms.summarize_algorithms.core.models import ResponseContext, Session +from src.benchmark.tool_plan_benchmarking.tools_and_schemas.parsed_jsons import TOOLS +from src.utils.system_prompt_builder import MemorySections, SystemPromptBuilder + + +class ResponseGenerator: + _MEMORY_MODE_BASELINE = "baseline" + _MEMORY_MODE_MEMORY = "memory" + """ + Generates the final assistant response given: + - the last dialogue session + - retrieved memory (code/tool/text) + - the current user query + + Depending on configuration, it can: + - return plain text (`StrOutputParser`) + - call tools (`bind_tools`) + - return structured JSON (`with_structured_output`) + """ + + def __init__( + self, + llm: BaseChatModel, + structure: dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + *, + max_prompt_tokens: int | None = None, + ) -> None: + self._llm = llm + self._structure = structure + self._tools = tools + self._max_prompt_tokens = max_prompt_tokens + self._prompt_builder = SystemPromptBuilder() + self._chain = self._build_chain() + + def _build_chain(self) -> Runnable: + """Build the runnable used for response generation. + + `generate_response()` already assembles the full prompt as a list of messages + (`full_history`), so the chain must accept a `list[BaseMessage]` directly. + + The branching logic for structured output and tool binding is intentionally kept. + """ + if self._structure and not self._tools: + return self._llm.with_structured_output(self._structure) + + if self._tools and not self._structure: + return self._llm.bind_tools(self._tools) + + if self._tools and self._structure: + tools = [self._get_return_action_plan(), *self._tools] + return self._llm.bind_tools(tools) + + return self._llm | StrOutputParser() + + def _get_return_action_plan(self) -> dict[str, Any]: + """ + Build an auxiliary tool spec that forces the model to return the structured JSON according to `self._structure`. + + :return: dict[str, Any]: tool spec for `bind_tools`. + """ + return { + "type": "function", + "function": { + "name": "return_action_plan", + "description": "Return the final JSON action plan strictly matching the schema.", + "parameters": self._structure, + }, + } + + def _prepare_history(self, sessions: Session, user_query: str) -> list[BaseMessage]: + """ + Prepare the message history that will follow the unified system instruction. + + Ensures the last message in the returned list is the latest user request. + + :param sessions: conversation session used as history. + :param user_query: latest user request (may already be present as the last user message in `sessions`). + :return: list[BaseMessage]: prepared history messages (no SystemMessage). 
+ """ + trimmed_history = trim_messages( + sessions.to_langchain_messages(), + token_counter=self._llm, + max_tokens=100000, + strategy="last", + include_system=False, + allow_partial=False, + ) + + if user_query.strip() != "": + should_append = True + if trimmed_history and isinstance(trimmed_history[-1], HumanMessage): + should_append = trimmed_history[-1].content != user_query + if should_append: + trimmed_history.append(HumanMessage(content=user_query)) + + return trimmed_history + + def _crop(self, messages: list[BaseMessage], max_tokens: int = 80000) -> list[BaseMessage]: + """Trim a list of messages to fit into a token budget. + + Mirrors `DialogueBaseline._crop()` behavior: + - preserves an initial `SystemMessage` (if present) + - ensures the first message after trimming is not a `ToolMessage` + """ + system_msg: SystemMessage | None = None + if messages and isinstance(messages[0], SystemMessage): + system_msg = messages[0] + messages_to_trim = messages[1:] + else: + messages_to_trim = messages + + total_tokens_before_crop = self._llm.get_num_tokens_from_messages(messages_to_trim) + logging.info(f"Total tokens before crop: {total_tokens_before_crop}") + + trimmed_messages = trim_messages( + messages_to_trim, + token_counter=self._llm, + max_tokens=max_tokens, + strategy="last", + include_system=True, + allow_partial=False, + ) + + while trimmed_messages and isinstance(trimmed_messages[0], ToolMessage): + trimmed_messages.pop(0) + + if system_msg is not None: + total_tokens_after_crop = self._llm.get_num_tokens_from_messages(trimmed_messages) + logging.info(f"Total tokens after crop (without system message): {total_tokens_after_crop}") + return [system_msg, *trimmed_messages] + + return trimmed_messages + + def _infer_memory_mode(self, memory: MemorySections) -> str: + """Infer memory mode for system prompt rendering. + + We consider the run as "memory" if any memory section is present and non-empty. + """ + sections = [ + memory.recap, + memory.memory_bank, + memory.code_knowledge, + memory.tool_memory, + ] + has_memory = any((s or "").strip() != "" for s in sections) + return self._MEMORY_MODE_MEMORY if has_memory else self._MEMORY_MODE_BASELINE + + def _build_system_message( + self, + *, + memory: MemorySections, + ) -> SystemMessage: + """ + Build the single unified SystemMessage from Jinja2 templates. + + Order: + 1) introduction.j2 + 2) memory injection (conditional) + 3) schema_and_tool.j2 + 4) bridge_to_conversation.j2 + + :param memory: memory sections to inject. + :return: SystemMessage: unified system instruction. + """ + system_prompt_text = self._prompt_builder.build( + schema=self._structure, + tools=TOOLS, + memory=memory, + memory_mode=self._infer_memory_mode(memory), + ) + return SystemMessage(content=system_prompt_text) + + def generate_response( + self, + *, + last_session: Session, + user_query: str, + memory: MemorySections, + ) -> ResponseContext: + """Generate a response using a single unified SystemMessage followed by the conversation history. + + :param last_session: the current conversation session (history used for response generation). + :param user_query: latest user request (must become the last HumanMessage). + :param memory: memory sections to inject into the system instruction. + :return: ResponseContext: raw model output and the prepared history sent to the model. 
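A hedged sketch of calling the generator directly (toy session and memory text; assumes `OPENAI_API_KEY` is set so `ChatOpenAI` picks it up, and that `MemorySections(recap=...)` accepts a plain string, as in `graph_nodes.py`):

```python
from langchain_openai import ChatOpenAI

from src.algorithms.summarize_algorithms.core.models import BaseBlock, OpenAIModels, Session
from src.algorithms.summarize_algorithms.core.response_generator import ResponseGenerator
from src.utils.system_prompt_builder import MemorySections

generator = ResponseGenerator(ChatOpenAI(model=OpenAIModels.GPT_4_O_MINI.value), max_prompt_tokens=80_000)

result = generator.generate_response(
    last_session=Session([BaseBlock(role="user", content="Summarize our CI setup.")]),
    user_query="Summarize our CI setup.",
    memory=MemorySections(recap="CI runs ruff, mypy and pytest on every PR."),
)
print(result.response)                 # plain text: no schema/tools were configured
print(len(result.prepared_history))    # unified SystemMessage + trimmed history
```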
+ """ + try: + history_messages: list[BaseMessage] = self._prepare_history(last_session, user_query) + + history_messages = [m for m in history_messages if not isinstance(m, SystemMessage)] + + system_message = self._build_system_message(memory=memory) + + final_prompt: list[BaseMessage] = [system_message, *history_messages] + + # Logging mirrors `DialogueBaseline`: show prompt token counts before/after crop. + system_tokens = self._llm.get_num_tokens_from_messages([system_message]) + history_tokens = self._llm.get_num_tokens_from_messages(history_messages) + total_tokens = self._llm.get_num_tokens_from_messages(final_prompt) + logging.info( + "Prompt tokens breakdown: system=%s history=%s total=%s", + system_tokens, + history_tokens, + total_tokens, + ) + + prompt_to_invoke = final_prompt + if self._max_prompt_tokens is not None: + prompt_to_invoke = self._crop(final_prompt, max_tokens=self._max_prompt_tokens) + logging.info( + "Prompt tokens after crop: total=%s", + self._llm.get_num_tokens_from_messages(prompt_to_invoke), + ) + + response = self._chain.invoke(prompt_to_invoke) + + return ResponseContext(response=response, prepared_history=prompt_to_invoke) + + except Exception as e: + raise ConnectionError(f"API request failed: {str(e)}") from e diff --git a/src/summarize_algorithms/memory_bank/__init__.py b/src/algorithms/summarize_algorithms/memory_bank/__init__.py similarity index 100% rename from src/summarize_algorithms/memory_bank/__init__.py rename to src/algorithms/summarize_algorithms/memory_bank/__init__.py diff --git a/src/algorithms/summarize_algorithms/memory_bank/dialogue_system.py b/src/algorithms/summarize_algorithms/memory_bank/dialogue_system.py new file mode 100644 index 0000000..df7c4a0 --- /dev/null +++ b/src/algorithms/summarize_algorithms/memory_bank/dialogue_system.py @@ -0,0 +1,57 @@ +from src.algorithms.summarize_algorithms.core.base_dialogue_system import ( + BaseDialogueSystem, +) +from src.algorithms.summarize_algorithms.core.memory_storage.memory_storage import ( + MemoryStorage, +) +from src.algorithms.summarize_algorithms.core.models import ( + MemoryBankDialogueState, + Session, +) +from src.algorithms.summarize_algorithms.memory_bank.prompts import ( + SESSION_SUMMARY_PROMPT, +) +from src.algorithms.summarize_algorithms.memory_bank.summarizer import SessionSummarizer + + +class MemoryBankDialogueSystem(BaseDialogueSystem): + """ + Implementation of the MemoryBank-style dialogue system. + + Summarizes each session into a compact representation stored in `text_memory_storage` and optionally augments + the prompt with retrieved code/tool memories (FAISS + embeddings). 
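A brief usage sketch with the optional vector memories enabled (toy session content; requires `OPENAI_API_KEY` for the chat model and embeddings):

```python
from src.algorithms.summarize_algorithms.core.models import BaseBlock, Session
from src.algorithms.summarize_algorithms.memory_bank.dialogue_system import MemoryBankDialogueSystem

sessions = [
    Session([BaseBlock(role="user", content="We store configs in YAML under configs/."),
             BaseBlock(role="assistant", content="Noted, I added a loader for them.")]),
    Session([BaseBlock(role="user", content="Where do we keep the configs again?")]),
]

# embed_code / embed_tool control whether separate MemoryStorage instances are created.
system = MemoryBankDialogueSystem(embed_code=True, embed_tool=True, system_name="memory_bank")
state = system.process_dialogue(sessions, system_prompt="You are a coding assistant.")
print(state.response)
```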
+ """ + + def _build_summarizer(self) -> SessionSummarizer: + return SessionSummarizer(self.memory_llm, SESSION_SUMMARY_PROMPT) + + def _get_initial_state( + self, sessions: list[Session], last_session: Session, query: str + ) -> MemoryBankDialogueState: + return MemoryBankDialogueState( + dialogue_sessions=sessions, + last_session=last_session, + code_memory_storage=( + MemoryStorage( + embeddings=self.embed_model, max_session_id=self.max_session_id + ) + if self.embed_code + else None + ), + tool_memory_storage=( + MemoryStorage( + embeddings=self.embed_model, max_session_id=self.max_session_id + ) + if self.embed_tool + else None + ), + query=query, + text_memory_storage=MemoryStorage( + embeddings=self.embed_model, max_session_id=self.max_session_id + ), + prepared_messages=[] + ) + + @property + def _get_dialogue_state_class(self) -> type: + return MemoryBankDialogueState diff --git a/src/summarize_algorithms/memory_bank/prompts.py b/src/algorithms/summarize_algorithms/memory_bank/prompts.py similarity index 100% rename from src/summarize_algorithms/memory_bank/prompts.py rename to src/algorithms/summarize_algorithms/memory_bank/prompts.py diff --git a/src/summarize_algorithms/memory_bank/summarizer.py b/src/algorithms/summarize_algorithms/memory_bank/summarizer.py similarity index 56% rename from src/summarize_algorithms/memory_bank/summarizer.py rename to src/algorithms/summarize_algorithms/memory_bank/summarizer.py index f62eeed..fb2d019 100644 --- a/src/summarize_algorithms/memory_bank/summarizer.py +++ b/src/algorithms/summarize_algorithms/memory_bank/summarizer.py @@ -3,15 +3,24 @@ from langchain_core.runnables import RunnableSerializable from pydantic import BaseModel, Field -from src.summarize_algorithms.core.base_summarizer import BaseSummarizer -from src.summarize_algorithms.core.models import BaseBlock +from src.algorithms.summarize_algorithms.core.base_summarizer import BaseSummarizer +from src.algorithms.summarize_algorithms.core.models import BaseBlock class SessionMemory(BaseModel): + """Pydantic schema for structured LLM output used by memory summarizers.""" + summary_messages: list[BaseBlock] = Field(description="Summary of session messages") class SessionSummarizer(BaseSummarizer): + """ + MemoryBank session summarizer. + + Summarizes a single session into a list of `BaseBlock` messages ("memory") which can later be retrieved and + injected into the response generation prompt. + """ + def _build_chain(self) -> RunnableSerializable[dict[str, Any], SessionMemory]: return cast( RunnableSerializable[dict, SessionMemory], @@ -19,6 +28,13 @@ def _build_chain(self) -> RunnableSerializable[dict[str, Any], SessionMemory]: ) def summarize(self, session_messages: str, session_id: int) -> list[BaseBlock]: + """ + Summarize one session into memory blocks. + + :param session_messages: stringified session content. + :param session_id: identifier injected into the prompt (useful for tracing/debugging). + :return: list[BaseBlock]: messages representing the session summary. 
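A sketch of invoking the summarizer on its own (invented session content; assumes `OPENAI_API_KEY` is available for `ChatOpenAI`):

```python
from langchain_openai import ChatOpenAI

from src.algorithms.summarize_algorithms.core.models import BaseBlock, OpenAIModels, Session
from src.algorithms.summarize_algorithms.memory_bank.prompts import SESSION_SUMMARY_PROMPT
from src.algorithms.summarize_algorithms.memory_bank.summarizer import SessionSummarizer

session = Session([BaseBlock(role="user", content="We switched the CI runner to uv.")])

summarizer = SessionSummarizer(ChatOpenAI(model=OpenAIModels.GPT_5_MINI.value), SESSION_SUMMARY_PROMPT)
memory_blocks = summarizer.summarize(session_messages=str(session), session_id=0)  # -> list[BaseBlock]
```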
+ """ try: response = self.chain.invoke( { diff --git a/src/algorithms/summarize_algorithms/recsum/__init__.py b/src/algorithms/summarize_algorithms/recsum/__init__.py new file mode 100644 index 0000000..7082f6a --- /dev/null +++ b/src/algorithms/summarize_algorithms/recsum/__init__.py @@ -0,0 +1,6 @@ +"""RecSum implementation (compatibility exports).""" + +from src.algorithms.summarize_algorithms.recsum.summarizer import RecursiveSummarizer + +__all__ = ["RecursiveSummarizer"] + diff --git a/src/algorithms/summarize_algorithms/recsum/dialogue_system.py b/src/algorithms/summarize_algorithms/recsum/dialogue_system.py new file mode 100644 index 0000000..1aacf7f --- /dev/null +++ b/src/algorithms/summarize_algorithms/recsum/dialogue_system.py @@ -0,0 +1,44 @@ + +from src.algorithms.summarize_algorithms.core.base_dialogue_system import ( + BaseDialogueSystem, +) +from src.algorithms.summarize_algorithms.core.memory_storage.memory_storage import ( + MemoryStorage, +) +from src.algorithms.summarize_algorithms.core.models import RecsumDialogueState, Session +from src.algorithms.summarize_algorithms.recsum.prompts import ( + MEMORY_UPDATE_PROMPT_TEMPLATE, +) +from src.algorithms.summarize_algorithms.recsum.summarizer import RecursiveSummarizer + + +class RecsumDialogueSystem(BaseDialogueSystem): + """ + Implementation of the RecSum-style dialogue system. + + Uses `RecursiveSummarizer` to iteratively update a text memory (and optionally vector-retrieved code/tool memory) + and then generates a final response via `BaseDialogueSystem`’s graph. + """ + + def _build_summarizer(self) -> RecursiveSummarizer: + return RecursiveSummarizer(self.memory_llm, MEMORY_UPDATE_PROMPT_TEMPLATE) + + def _get_initial_state( + self, sessions: list[Session], last_session: Session, query: str + ) -> RecsumDialogueState: + return RecsumDialogueState( + dialogue_sessions=sessions, + last_session=last_session, + code_memory_storage=MemoryStorage( + embeddings=self.embed_model, max_session_id=self.max_session_id + ), + tool_memory_storage=MemoryStorage( + embeddings=self.embed_model, max_session_id=self.max_session_id + ), + query=query, + prepared_messages=[] + ) + + @property + def _get_dialogue_state_class(self) -> type: + return RecsumDialogueState diff --git a/src/summarize_algorithms/recsum/prompts.py b/src/algorithms/summarize_algorithms/recsum/prompts.py similarity index 100% rename from src/summarize_algorithms/recsum/prompts.py rename to src/algorithms/summarize_algorithms/recsum/prompts.py diff --git a/src/summarize_algorithms/recsum/summarizer.py b/src/algorithms/summarize_algorithms/recsum/summarizer.py similarity index 53% rename from src/summarize_algorithms/recsum/summarizer.py rename to src/algorithms/summarize_algorithms/recsum/summarizer.py index a2f3c6d..e5888f8 100644 --- a/src/summarize_algorithms/recsum/summarizer.py +++ b/src/algorithms/summarize_algorithms/recsum/summarizer.py @@ -2,12 +2,19 @@ from langchain_core.runnables import RunnableSerializable -from src.summarize_algorithms.core.base_summarizer import BaseSummarizer -from src.summarize_algorithms.core.models import BaseBlock -from src.summarize_algorithms.memory_bank.summarizer import SessionMemory +from src.algorithms.summarize_algorithms.core.base_summarizer import BaseSummarizer +from src.algorithms.summarize_algorithms.core.models import BaseBlock +from src.algorithms.summarize_algorithms.memory_bank.summarizer import SessionMemory class RecursiveSummarizer(BaseSummarizer): + """ + RecSum memory updater. 
+ + Given the previous memory + the current dialogue context, produces a list of `BaseBlock` messages representing + the updated memory. + """ + def _build_chain(self) -> RunnableSerializable[dict[str, Any], SessionMemory]: return cast( RunnableSerializable[dict, SessionMemory], @@ -15,6 +22,13 @@ def _build_chain(self) -> RunnableSerializable[dict[str, Any], SessionMemory]: ) def summarize(self, previous_memory: str, dialogue_context: str) -> list[BaseBlock]: + """ + Update recursive memory given previous memory and the latest dialogue context. + + :param previous_memory: previous memory string. + :param dialogue_context: current dialogue context string. + :return: list[BaseBlock]: updated memory blocks. + """ try: response = self.chain.invoke( { diff --git a/src/summarize_algorithms/recsum/__init__.py b/src/benchmark/__init__.py similarity index 100% rename from src/summarize_algorithms/recsum/__init__.py rename to src/benchmark/__init__.py diff --git a/src/benchmark/logger/__init__.py b/src/benchmark/logger/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/benchmark/logger/base_logger.py b/src/benchmark/logger/base_logger.py new file mode 100644 index 0000000..4685926 --- /dev/null +++ b/src/benchmark/logger/base_logger.py @@ -0,0 +1,100 @@ +import logging +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +from src.algorithms.summarize_algorithms.core.models import ( + DialogueState, + MemoryBankDialogueState, + RecsumDialogueState, + Session, +) +from src.benchmark.models.dtos import BaseRecord, MetricState +from src.benchmark.utils.json_log_utils import JsonLogUtils + + +class BaseLogger(ABC): + def __init__(self, logs_dir: str | Path = "logs/memory") -> None: + os.makedirs(logs_dir, exist_ok=True) + self.log_dir = Path(logs_dir) + self.logger = logging.getLogger(__name__) + + @abstractmethod + def log_iteration( + self, + system_name: str, + query: str, + iteration: int, + sessions: list[Session], + state: DialogueState, + subdirectory: Path, + metrics: list[MetricState] | None = None, + save: bool = True, + ) -> BaseRecord: + ... + + @abstractmethod + def fetch_logs( + self, + system_names: list[str], + subdirectory: Path, + ) -> list[BaseRecord]: + ... + + def _prepare_and_save_log( + self, + record: dict[str, Any], + subdirectory: Path, + system_name: str, + iteration: int, + metrics: list[MetricState] | None = None + ) -> None: + if metrics is not None: + record["metric"] = BaseLogger.metrics_to_dicts(metrics) + + if subdirectory is not None: + directory: Path = self.log_dir / subdirectory + os.makedirs(directory, exist_ok=True) + else: + directory = self.log_dir + + self.save_log_dict(directory, iteration, record, system_name) + + def save_log_dict( + self, + directory: Path, + iteration: int, + record: dict[str, Any], + system_name: str + ) -> None: + # Use overwrite mode: in "evaluate-by-logs" we update the same file in-place. + # (Previous append mode could create multiple JSON objects in one file.) 
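+        # Resulting file name: "<system_name>-<record timestamp>.json" inside `directory`.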
+ target = directory / (system_name + "-" + str(record["timestamp"]) + ".json") + JsonLogUtils.write(target, record) + + self.logger.info(f"Saved successfully iteration {iteration} to {self.log_dir}") + + @staticmethod + def metrics_to_dicts(metrics: list[MetricState]) -> list[dict[str, Any]]: + return [ + {"metric_name": metric.metric_name.value, "metric_value": metric.metric_value} + for metric in metrics + ] + + @staticmethod + def _serialize_memories( + state: DialogueState + ) -> dict[str, Any]: + result: dict[str, Any] = {} + + if state.code_memory_storage is not None: + result["code_memory_storage"] = state.code_memory_storage.to_dict() + if state.tool_memory_storage is not None: + result["tool_memory_storage"] = state.tool_memory_storage.to_dict() + if isinstance(state, MemoryBankDialogueState): + result["text_memory_storage"] = state.text_memory_storage.to_dict() + if isinstance(state, RecsumDialogueState): + result["text_memory"] = state.text_memory + + return result diff --git a/src/benchmark/logger/baseline_logger.py b/src/benchmark/logger/baseline_logger.py new file mode 100644 index 0000000..b413866 --- /dev/null +++ b/src/benchmark/logger/baseline_logger.py @@ -0,0 +1,67 @@ +from datetime import datetime +from pathlib import Path + +from typing_extensions import override + +from src.algorithms.summarize_algorithms.core.models import DialogueState, Session +from src.benchmark.logger.base_logger import BaseLogger +from src.benchmark.models.dtos import BaseRecord, MetricState +from src.benchmark.utils.json_log_utils import JsonLogUtils + + +class BaselineLogger(BaseLogger): + @override + def log_iteration( + self, + system_name: str, + query: str, + iteration: int, + sessions: list[Session], + state: DialogueState, + subdirectory: Path, + metrics: list[MetricState] | None = None, + save: bool = True, + ) -> BaseRecord: + self.logger.info(f"Logging iteration {iteration} to {self.log_dir}") + + record = { + "timestamp": datetime.now().isoformat(), + "iteration": iteration, + "system": system_name, + "query": query, + "response": getattr(state, "response", None), + "sessions": [s.to_dict() for s in sessions], + "prepared_messages": [s.model_dump(mode="json") for s in state.prepared_messages], + } + + if save: + self._prepare_and_save_log(record, subdirectory, system_name, iteration, metrics) + + return BaseRecord.from_dict(record) + + @override + def fetch_logs( + self, + system_names: list[str], + subdirectory: Path, + ) -> list[BaseRecord]: + """Load saved benchmark log records. + + Logs are expected under: + `///*.json`. + + Returns records parsed via `BaseRecord.from_dict()`. + """ + records: list[BaseRecord] = [] + for system_name in system_names: + directory = self.log_dir / system_name / subdirectory + if not directory.exists(): + continue + + for path in sorted(directory.glob("*.json")): + # Logs are written with `indent=4`, so each record is multi-line JSON. + # Also keep compatibility with multiple objects appended to the same file. 
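+                # `load_log_payloads` is expected to yield one dict per JSON object found in the file.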
+ for payload in JsonLogUtils.load_log_payloads(path): + records.append(BaseRecord.from_dict(payload)) + + return records diff --git a/src/benchmark/logger/memory_logger.py b/src/benchmark/logger/memory_logger.py new file mode 100644 index 0000000..1d2fc31 --- /dev/null +++ b/src/benchmark/logger/memory_logger.py @@ -0,0 +1,72 @@ +from datetime import datetime +from pathlib import Path + +from typing_extensions import override + +from src.algorithms.summarize_algorithms.core.models import DialogueState, Session +from src.benchmark.logger.base_logger import BaseLogger +from src.benchmark.models.dtos import BaseRecord, MemoryRecord, MetricState +from src.benchmark.utils.json_log_utils import JsonLogUtils + + +class MemoryLogger(BaseLogger): + @override + def log_iteration( + self, + system_name: str, + query: str, + iteration: int, + sessions: list[Session], + state: DialogueState, + subdirectory: Path, + metrics: list[MetricState] | None = None, + save: bool = True, + ) -> MemoryRecord: + self.logger.info(f"Logging iteration {iteration} to {self.log_dir}") + + if state is None: + raise ValueError("'state' argument is necessary for MemoryLogger") + + record = { + "timestamp": datetime.now().isoformat(), + "iteration": iteration, + "system": system_name, + "query": query, + "response": getattr(state, "response", None), + "memory": MemoryLogger._serialize_memories(state), + "sessions": [s.to_dict() for s in sessions], + "prepared_messages": [s.model_dump(mode="json") for s in state.prepared_messages], + } + + if save: + self._prepare_and_save_log(record, subdirectory, system_name, iteration, metrics) + + return MemoryRecord.from_dict(record) + + @override + def fetch_logs( + self, + system_names: list[str], + subdirectory: Path, + ) -> list[BaseRecord]: + """Load saved benchmark log records. + + Logs are expected under: + `///*.json`. + + Returns records parsed via `MemoryRecord.from_dict()`. 
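+
+        Concretely, records are read from the files matching `self.log_dir / system_name / subdirectory / "*.json"`.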
+ """ + records: list[BaseRecord] = [] + for system_name in system_names: + directory = self.log_dir / system_name / subdirectory + if not directory.exists(): + continue + + for path in sorted(directory.glob("*.json")): + for payload in JsonLogUtils.load_log_payloads(path): + if payload.get("memory") is None: + records.append(BaseRecord.from_dict(payload)) + else: + records.append(MemoryRecord.from_dict(payload)) + + return records diff --git a/src/benchmark/models/__init__.py b/src/benchmark/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/benchmark/models/dtos.py b/src/benchmark/models/dtos.py new file mode 100644 index 0000000..21764a9 --- /dev/null +++ b/src/benchmark/models/dtos.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass, field +from decimal import Decimal +from typing import Any + +from dataclasses_json import DataClassJsonMixin, dataclass_json + +from src.algorithms.summarize_algorithms.core.models import BaseBlock, OpenAIModels +from src.benchmark.models.enums import MetricType + + +@dataclass +class DividedSession: + reference: list[BaseBlock] + past_interactions: list[BaseBlock] + + +@dataclass_json +@dataclass +class MetricState: + metric_name: MetricType + metric_value: float | int | bool | Decimal + + +@dataclass +class BaseRecord(DataClassJsonMixin): + timestamp: str + iteration: int + system: str + query: str + response: Any + sessions: list[dict[str, Any]] + prepared_messages: list[dict[str, Any]] = field(default_factory=list) + metric: list[MetricState] | None = field(default=None) + + +@dataclass +class MemoryRecord(BaseRecord): + memory: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class Evaluation: + memory: dict[str, Any] = field(default_factory=dict) + + +@dataclass_json +@dataclass(frozen=True) +class ModelPrice: + input_per_million: Decimal + output_per_million: Decimal + + +MODEL_PRICES: dict[OpenAIModels, ModelPrice] = { + OpenAIModels.GPT_4_O: ModelPrice( + input_per_million=Decimal("2.50"), + output_per_million=Decimal("10.00"), + ), + OpenAIModels.GPT_5_MINI: ModelPrice( + input_per_million=Decimal("0.250"), + output_per_million=Decimal("2.000"), + ), + OpenAIModels.GPT_5_NANO: ModelPrice( + input_per_million=Decimal("0.05"), + output_per_million=Decimal("0.40"), + ), + OpenAIModels.GPT_4_1_MINI: ModelPrice( + input_per_million=Decimal("0.40"), + output_per_million=Decimal("1.60"), + ), + OpenAIModels.GPT_4_1: ModelPrice( + input_per_million=Decimal("2.00"), + output_per_million=Decimal("8.00"), + ), + OpenAIModels.GPT_3_5_TURBO: ModelPrice( + input_per_million=Decimal("0.50"), + output_per_million=Decimal("1.50"), + ), + OpenAIModels.GPT_4_O_MINI: ModelPrice( + input_per_million=Decimal("0.15"), + output_per_million=Decimal("0.60"), + ), +} + + +@dataclass +class TokenInfo(DataClassJsonMixin): + model: OpenAIModels + price: ModelPrice + input_tokens: int + output_tokens: int + input_price: Decimal + output_price: Decimal + total_price: Decimal diff --git a/src/benchmark/models/enums.py b/src/benchmark/models/enums.py new file mode 100644 index 0000000..674e357 --- /dev/null +++ b/src/benchmark/models/enums.py @@ -0,0 +1,30 @@ +from enum import Enum + + +class MetricType(Enum): + COHERENCE = "COHERENCE" + F1_TOOL_STRICT = "F1_TOOL_STRICT" + F1_TOOL_ARGUMENTS_SIMILARITY = "F1_TOOL_ARGUMENTS_SIMILARITY" + F1_TOOL = "F1_TOOL" + + +class AlgorithmName(Enum): + BASE_RECSUM = "base_recsum" + BASE_MEMORY_BANK = "base_memory_bank" + RAG_RECSUM = "rag_recsum" + RAG_MEMORY_BANK = "rag_memory_bank" + 
FULL_BASELINE = "full_baseline" + LAST_BASELINE = "last_baseline" + SHORT_TOOLS = "short_tools" + WEIGHTS = "weights" + + +class AlgorithmDirectory(Enum): + BASE_RECSUM = "BaseRecsum" + BASE_MEMORY_BANK = "BaseMemoryBank" + RAG_RECSUM = "RagRecsum" + RAG_MEMORY_BANK = "RagMemoryBank" + FULL_BASELINE = "FullBaseline" + LAST_BASELINE = "LastBaseline" + SHORT_TOOLS = "ShortTools" + WEIGHTS = "Weights" diff --git a/src/benchmark/simple_benchmarking/__init__.py b/src/benchmark/simple_benchmarking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/benchmark/simple_benchmarking/agent_chat/__init__.py b/src/benchmark/simple_benchmarking/agent_chat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/benchmarking/agent_chat/calculate_agent_chat_response_metrics.py b/src/benchmark/simple_benchmarking/agent_chat/calculate_agent_chat_response_metrics.py similarity index 92% rename from src/benchmarking/agent_chat/calculate_agent_chat_response_metrics.py rename to src/benchmark/simple_benchmarking/agent_chat/calculate_agent_chat_response_metrics.py index 90edf6d..7a32272 100644 --- a/src/benchmarking/agent_chat/calculate_agent_chat_response_metrics.py +++ b/src/benchmark/simple_benchmarking/agent_chat/calculate_agent_chat_response_metrics.py @@ -5,18 +5,22 @@ from dataclasses import dataclass, field from pathlib import Path -from src.benchmarking.agent_chat.deserialize_agent_chat import ChatDataset -from src.benchmarking.baseline import DialogueBaseline -from src.benchmarking.llm_evaluation import ( +from src.algorithms.simple_algorithms.dialogue_baseline import DialogueBaseline +from src.algorithms.summarize_algorithms.core.models import Session +from src.algorithms.summarize_algorithms.memory_bank.dialogue_system import ( + MemoryBankDialogueSystem, +) +from src.algorithms.summarize_algorithms.recsum.dialogue_system import ( + RecsumDialogueSystem, +) +from src.benchmark.simple_benchmarking.agent_chat.deserialize_agent_chat import ( + ChatDataset, +) +from src.benchmark.simple_benchmarking.llm_evaluation import ( ComparisonResult, LLMChatAgentEvaluation, SingleChatAgentResult, ) -from src.summarize_algorithms.core.models import Session -from src.summarize_algorithms.memory_bank.dialogue_system import ( - MemoryBankDialogueSystem, -) -from src.summarize_algorithms.recsum.dialogue_system import RecsumDialogueSystem @dataclass @@ -62,15 +66,15 @@ def __init__(self) -> None: self.full_baseline = DialogueBaseline("FullBaseline") self.last_baseline = DialogueBaseline("LastBaseline") - self.path_to_save = Path("/Users/mikhailkharlamov/Documents/RecapKt/src/benchmarking/agent_chat/results") + self.path_to_save = Path("/Users/mikhailkharlamov/Documents/RecapKt/src/benchmark/agent_chat/results") def calculate(self) -> None: dialogue = self.dataset.sessions for i in range(len(dialogue)): self.logger.info(f"Processing dialogue {i + 1}/{len(dialogue)}") - self._process(dialogue[: i + 1], i + 1) + self._process(dialogue[: i + 1]) - def _process(self, sessions: list[Session], iteration: int) -> None: + def _process(self, sessions: list[Session]) -> None: last_session = sessions[-1] query = "" for i in range(len(last_session.messages) - 1, -1, -1): @@ -100,11 +104,11 @@ def _process(self, sessions: list[Session], iteration: int) -> None: ) self.logger.info("Started computing full session baseline response") full_sessions_baseline_response = self.full_baseline.process_dialogue( - sessions, query, iteration + sessions, query ) self.logger.info("Started computing last session 
baseline response") last_session_baseline_response = self.last_baseline.process_dialogue( - [sessions[-1]], query, iteration + [sessions[-1]], query ) self.logger.info("Started computing base recsum single response score") @@ -135,7 +139,6 @@ def _process(self, sessions: list[Session], iteration: int) -> None: assistant_answer=last_session_baseline_response.response, ) - self._single_eval_update( self.base_recsum_single_result, base_recsum_single_score ) @@ -284,8 +287,12 @@ def avg(lst: list[int]) -> float: ("Full Sessions Baseline", self.full_baseline), ("Last Session Baseline", self.last_baseline), ]: + pt = algo.prompt_tokens # type: ignore[attr-defined] + ct = algo.completion_tokens # type: ignore[attr-defined] + cost = algo.total_cost # type: ignore[attr-defined] + print( - f"{name:<25} | {algo.prompt_tokens:<15} | {algo.completion_tokens:<18} | {algo.total_cost:<12.5f}" # type: ignore + f"{name:<25} | {pt:<15} | {ct:<18} | {cost:<12.5f}" ) print("\n===Processed Messages ===") diff --git a/src/benchmarking/agent_chat/create_chat.py b/src/benchmark/simple_benchmarking/agent_chat/create_chat.py similarity index 93% rename from src/benchmarking/agent_chat/create_chat.py rename to src/benchmark/simple_benchmarking/agent_chat/create_chat.py index eaa9f0b..d2ea1c5 100644 --- a/src/benchmarking/agent_chat/create_chat.py +++ b/src/benchmark/simple_benchmarking/agent_chat/create_chat.py @@ -6,9 +6,9 @@ class ChatSessionCombiner: def __init__( - self, - file_list: list[str], - output_file: str = "combined_chat_history_sessions.json", + self, + file_list: list[str], + output_file: str = "combined_chat_history_sessions.json", ) -> None: self.file_list = file_list self.output_file = output_file @@ -34,7 +34,7 @@ def _load_chat_file(file_name: str) -> dict[str, Any] | None: return None def _create_session_entry( - self, file_name: str, data: dict[str, Any] + self, file_name: str, data: dict[str, Any] ) -> dict[str, Any]: session_id = self._extract_session_id(file_name) return {"session_id": session_id, "messages": data} diff --git a/src/benchmarking/agent_chat/deserialize_agent_chat.py b/src/benchmark/simple_benchmarking/agent_chat/deserialize_agent_chat.py similarity index 91% rename from src/benchmarking/agent_chat/deserialize_agent_chat.py rename to src/benchmark/simple_benchmarking/agent_chat/deserialize_agent_chat.py index 1032625..dc80db0 100644 --- a/src/benchmarking/agent_chat/deserialize_agent_chat.py +++ b/src/benchmark/simple_benchmarking/agent_chat/deserialize_agent_chat.py @@ -1,9 +1,10 @@ import json import re -from typing import Any, Iterator +from collections.abc import Iterator +from typing import Any -from src.summarize_algorithms.core.models import ( +from src.algorithms.summarize_algorithms.core.models import ( BaseBlock, CodeBlock, Session, @@ -23,7 +24,7 @@ def process_message(cls, message: dict[str, Any]) -> list[BaseBlock]: last_end = 0 for match in cls.CODE_PATTERN.finditer(message_text): - before_code = message_text[last_end : match.start()].strip() + before_code = message_text[last_end: match.start()].strip() code_content = match.group(1).strip() if before_code: @@ -63,7 +64,7 @@ def process_tool_calls(cls, messages: list[dict[str, Any]]) -> list[BaseBlock]: tool_calls = assistant_message.get("tool_calls", []) tool_responses = tool_message.get("tool_responses", []) - for tool_call, tool_response in zip(tool_calls, tool_responses): + for tool_call, tool_response in zip(tool_calls, tool_responses, strict=False): if tool_content is None: tool_content = ( f"name: 
{tool_call['name']}\narguments: {tool_call['arguments']}\n" @@ -106,9 +107,9 @@ def total_messages(self) -> int: @classmethod def from_file( - cls, - file_name: str = "/Users/mikhailkharlamov/Documents/RecapKt/src/benchmarking/agent_chat/" - "combined_chat_history_sessions.json", + cls, + file_name: str = "/Users/mikhailkharlamov/Documents/RecapKt/src/benchmark/agent_chat/" + "combined_chat_history_sessions.json", ) -> "ChatDataset": processor = MessageProcessor() sessions = [] diff --git a/src/benchmarking/calculate_mcp_memory_metrics.py b/src/benchmark/simple_benchmarking/calculate_mcp_memory_metrics.py similarity index 95% rename from src/benchmarking/calculate_mcp_memory_metrics.py rename to src/benchmark/simple_benchmarking/calculate_mcp_memory_metrics.py index c9e28e6..9c53105 100644 --- a/src/benchmarking/calculate_mcp_memory_metrics.py +++ b/src/benchmark/simple_benchmarking/calculate_mcp_memory_metrics.py @@ -3,8 +3,13 @@ from datetime import datetime from typing import Any -from src.benchmarking.llm_evaluation import LLMMemoryEvaluation -from src.benchmarking.metric_calculator import ( +from src.algorithms.summarize_algorithms.core.models import RecsumDialogueState +from src.algorithms.summarize_algorithms.memory_bank.dialogue_system import ( + MemoryBankDialogueState, + MemoryBankDialogueSystem, +) +from src.benchmark.simple_benchmarking.llm_evaluation import LLMMemoryEvaluation +from src.benchmark.simple_benchmarking.metric_calculator import ( CalculateMCPMetrics, MCPResponseResults, MetricStats, @@ -12,12 +17,7 @@ RawSemanticData, SystemResults, ) -from src.benchmarking.semantic_similarity import SemanticSimilarity -from src.summarize_algorithms.core.models import RecsumDialogueState -from src.summarize_algorithms.memory_bank.dialogue_system import ( - MemoryBankDialogueState, - MemoryBankDialogueSystem, -) +from src.utils.semantic_similarity import SemanticSimilarity class CalculateMCPMemoryMetrics(CalculateMCPMetrics): diff --git a/src/benchmarking/calculate_mcp_response_metrics.py b/src/benchmark/simple_benchmarking/calculate_mcp_response_metrics.py similarity index 91% rename from src/benchmarking/calculate_mcp_response_metrics.py rename to src/benchmark/simple_benchmarking/calculate_mcp_response_metrics.py index 45fcef0..de5ac60 100644 --- a/src/benchmarking/calculate_mcp_response_metrics.py +++ b/src/benchmark/simple_benchmarking/calculate_mcp_response_metrics.py @@ -3,9 +3,9 @@ from datetime import datetime from typing import Any -from src.benchmarking.baseline import DialogueBaseline -from src.benchmarking.llm_evaluation import LLMResponseEvaluation -from src.benchmarking.metric_calculator import ( +from src.algorithms.simple_algorithms.dialogue_baseline import DialogueBaseline +from src.benchmark.simple_benchmarking.llm_evaluation import LLMResponseEvaluation +from src.benchmark.simple_benchmarking.metric_calculator import ( CalculateMCPMetrics, MCPResponseResults, MetricStats, @@ -13,7 +13,7 @@ RawSemanticData, SystemResults, ) -from src.benchmarking.semantic_similarity import SemanticSimilarity +from src.utils.semantic_similarity import SemanticSimilarity class CalculateMCPResponseMetrics(CalculateMCPMetrics): @@ -94,7 +94,7 @@ def _process_dialogue(self, dialogue: list, dialogue_index: int) -> None: recsum_response = self.recsum.process_dialogue( dialogue, query.content ).response - baseline_response = self.baseline.process_dialogue(dialogue, query.content, self.message_count) + baseline_response = self.baseline.process_dialogue(dialogue, query.content) 
self._update_semantic_scores( recsum_response, baseline_response.response, ideal_response.content @@ -112,7 +112,7 @@ def _process_dialogue(self, dialogue: list, dialogue_index: int) -> None: ) def _update_semantic_scores( - self, recsum_response: str, baseline_response: str, ideal_response: str + self, recsum_response: str | dict[str, Any], baseline_response: str | dict[str, Any], ideal_response: str ) -> None: recsum_score = self.semantic_scorer.compute_similarity( recsum_response, ideal_response @@ -129,7 +129,7 @@ def _update_semantic_scores( self._baseline_semantic_data.f1.append(baseline_score.f1) def _update_llm_single_scores( - self, recsum_response: str, baseline_response: str, context: str, memory: str + self, recsum_response: str | dict[str, Any], baseline_response: str | dict[str, Any], context: str, memory: str ) -> None: recsum_score = self.llm_scorer.evaluate_single( context=context, memory=memory, response=recsum_response @@ -148,7 +148,7 @@ def _update_llm_single_scores( self._baseline_llm_data.coherency.append(baseline_score.coherency_score) def _update_llm_pairwise_scores( - self, context: str, memory: str, recsum_response: str, baseline_response: str + self, context: str, memory: str, recsum_response: str | dict[str, Any], baseline_response: str | dict[str, Any] ) -> None: randomize_order = random.random() < 0.5 diff --git a/src/benchmarking/deserialize_mcp_data.py b/src/benchmark/simple_benchmarking/deserialize_mcp_data.py similarity index 96% rename from src/benchmarking/deserialize_mcp_data.py rename to src/benchmark/simple_benchmarking/deserialize_mcp_data.py index 539a4c2..6b18937 100644 --- a/src/benchmarking/deserialize_mcp_data.py +++ b/src/benchmark/simple_benchmarking/deserialize_mcp_data.py @@ -5,7 +5,7 @@ from datasets import load_dataset -from src.summarize_algorithms.core.models import BaseBlock, Session +from src.algorithms.summarize_algorithms.core.models import BaseBlock, Session @dataclass @@ -86,10 +86,10 @@ def _extract_sessions(dialogue_data: dict[str, list[Any]]) -> list[Session]: dialogue_sessions = dialogue_data.get("dialogue", []) speaker_sessions = dialogue_data.get("speaker", []) - for dialogue_msgs, speakers in zip(dialogue_sessions, speaker_sessions): + for dialogue_msgs, speakers in zip(dialogue_sessions, speaker_sessions, strict=False): messages = [ BaseBlock(role=speaker, content=message) - for message, speaker in zip(dialogue_msgs, speakers) + for message, speaker in zip(dialogue_msgs, speakers, strict=False) ] sessions.append(Session(messages)) diff --git a/src/benchmarking/llm_evaluation.py b/src/benchmark/simple_benchmarking/llm_evaluation.py similarity index 90% rename from src/benchmarking/llm_evaluation.py rename to src/benchmark/simple_benchmarking/llm_evaluation.py index 8946e49..4810646 100644 --- a/src/benchmarking/llm_evaluation.py +++ b/src/benchmark/simple_benchmarking/llm_evaluation.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Generic, Optional, TypeVar +from typing import Any, TypeVar from dotenv import load_dotenv from langchain_core.language_models import BaseChatModel @@ -11,7 +11,8 @@ from langchain_openai import ChatOpenAI from pydantic import BaseModel, Field, SecretStr -from src.benchmarking.prompts import ( +from src.algorithms.summarize_algorithms.core.models import OpenAIModels +from src.benchmark.simple_benchmarking.prompts import ( PAIRWISE_EVALUATION_AGENT_RESPONSE, PAIRWISE_EVALUATION_MEMORY_PROMPT, PAIRWISE_EVALUATION_RESPONSE_PROMPT, @@ -19,7 +20,6 @@ 
SINGLE_EVALUATION_MEMORY_PROMPT, SINGLE_EVALUATION_RESPONSE_PROMPT, ) -from src.summarize_algorithms.core.models import OpenAIModels class ComparisonResult(Enum): @@ -80,14 +80,14 @@ class PairwiseChatAgentResult(BaseModel): PairwiseResultType = TypeVar("PairwiseResultType", bound=BaseModel) -class BaseLLMEvaluation(Generic[SingleResultType, PairwiseResultType], ABC): - def __init__(self, llm: Optional[BaseChatModel] = None) -> None: +class BaseLLMEvaluation[SingleResultType, PairwiseResultType](ABC): + def __init__(self, llm: BaseChatModel | None = None) -> None: load_dotenv() api_key: str | None = os.getenv("OPENAI_API_KEY") if api_key is not None: self.llm = llm or ChatOpenAI( - model=OpenAIModels.GPT_5_MINI.value, + model=OpenAIModels.GPT_4_O_MINI.value, #changed api_key=SecretStr(api_key)) else: raise ValueError("OPENAI_API_KEY environment variable is not loaded") @@ -123,7 +123,7 @@ def _build_pairwise_eval_chain(self) -> RunnableSerializable[dict[str, str], Any ) @staticmethod - def _safe_invoke(chain: RunnableSerializable, params: dict[str, str]) -> Any: + def _safe_invoke(chain: RunnableSerializable, params: dict[str, Any]) -> Any: try: return chain.invoke(params) except Exception as e: @@ -143,12 +143,12 @@ def _get_single_result_model(self) -> type[SingleResult]: def _get_pairwise_result_model(self) -> type[PairwiseResult]: return PairwiseResult - def evaluate_single(self, context: str, memory: str, response: str) -> SingleResult: + def evaluate_single(self, context: str, memory: str, response: str | dict[str, Any]) -> SingleResult: params = {"context": context, "memory": memory, "response": response} return self._safe_invoke(self.single_eval_chain, params) def evaluate_pairwise( - self, context: str, memory: str, first_response: str, second_response: str + self, context: str, memory: str, first_response: str | dict[str, Any], second_response: str | dict[str, Any] ) -> PairwiseResult: params = { "context": context, @@ -203,7 +203,7 @@ def _get_pairwise_result_model(self) -> type[PairwiseChatAgentResult]: return PairwiseChatAgentResult def evaluate_single( - self, dialogue_context: str, assistant_answer: str + self, dialogue_context: str, assistant_answer: str | dict[str, Any] ) -> SingleChatAgentResult: params = { "dialogue_context": dialogue_context, @@ -212,7 +212,7 @@ def evaluate_single( return self._safe_invoke(self.single_eval_chain, params) def evaluate_pairwise( - self, dialogue_context: str, first_answer: str, second_answer: str + self, dialogue_context: str, first_answer: str | dict[str, Any], second_answer: str | dict[str, Any] ) -> PairwiseChatAgentResult: params = { "dialogue_context": dialogue_context, diff --git a/src/benchmarking/metric_calculator.py b/src/benchmark/simple_benchmarking/metric_calculator.py similarity index 88% rename from src/benchmarking/metric_calculator.py rename to src/benchmark/simple_benchmarking/metric_calculator.py index 39266fa..eadc11c 100644 --- a/src/benchmarking/metric_calculator.py +++ b/src/benchmark/simple_benchmarking/metric_calculator.py @@ -4,29 +4,31 @@ from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Dict, List +from typing import Any import numpy as np from pydantic import BaseModel -from src.benchmarking.deserialize_mcp_data import MCPDataset -from src.benchmarking.llm_evaluation import ComparisonResult -from src.summarize_algorithms.recsum.dialogue_system import RecsumDialogueSystem +from 
src.algorithms.summarize_algorithms.recsum.dialogue_system import ( + RecsumDialogueSystem, +) +from src.benchmark.simple_benchmarking.deserialize_mcp_data import MCPDataset +from src.benchmark.simple_benchmarking.llm_evaluation import ComparisonResult @dataclass class RawSemanticData: - precision: List[float] = field(default_factory=list) - recall: List[float] = field(default_factory=list) - f1: List[float] = field(default_factory=list) + precision: list[float] = field(default_factory=list) + recall: list[float] = field(default_factory=list) + f1: list[float] = field(default_factory=list) @dataclass class RawLLMData: - faithfulness: List[float] = field(default_factory=list) - informativeness: List[float] = field(default_factory=list) - coherency: List[float] = field(default_factory=list) + faithfulness: list[float] = field(default_factory=list) + informativeness: list[float] = field(default_factory=list) + coherency: list[float] = field(default_factory=list) @dataclass @@ -38,7 +40,7 @@ class MetricStats: count: int = 0 @classmethod - def from_values(cls, values: List[float]) -> "MetricStats": + def from_values(cls, values: list[float]) -> "MetricStats": if not values: return cls() @@ -65,13 +67,13 @@ class SystemResults: @dataclass class PairwiseResults: - faithfulness: Dict[str, int] = field( + faithfulness: dict[str, int] = field( default_factory=lambda: {"recsum": 0, "baseline": 0, "draw": 0} ) - informativeness: Dict[str, int] = field( + informativeness: dict[str, int] = field( default_factory=lambda: {"recsum": 0, "baseline": 0, "draw": 0} ) - coherency: Dict[str, int] = field( + coherency: dict[str, int] = field( default_factory=lambda: {"recsum": 0, "baseline": 0, "draw": 0} ) @@ -81,11 +83,11 @@ def get_total_count(self) -> int: @dataclass class MCPResult: - metadata: Dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) recsum_results: SystemResults = field(default_factory=SystemResults) pairwise_results: PairwiseResults = field(default_factory=PairwiseResults) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return asdict(self) diff --git a/src/benchmarking/prompts.py b/src/benchmark/simple_benchmarking/prompts.py similarity index 99% rename from src/benchmarking/prompts.py rename to src/benchmark/simple_benchmarking/prompts.py index 1d0846c..83ce3f2 100644 --- a/src/benchmarking/prompts.py +++ b/src/benchmark/simple_benchmarking/prompts.py @@ -72,7 +72,6 @@ """ ) - SINGLE_EVALUATION_MEMORY_PROMPT = PromptTemplate.from_template( """ You are a meticulous and impartial evaluator. Your task is to assess the quality of the `Generated Memory` @@ -201,7 +200,6 @@ one potential improvement). 
""") - PAIRWISE_EVALUATION_AGENT_RESPONSE = PromptTemplate.from_template(""" You are a highly critical expert evaluator comparing two AI assistant answers to the same user request in a dialogue that may contain: diff --git a/src/benchmark/tool_plan_benchmarking/__init__.py b/src/benchmark/tool_plan_benchmarking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/benchmark/tool_plan_benchmarking/calculator.py b/src/benchmark/tool_plan_benchmarking/calculator.py new file mode 100644 index 0000000..589f1b1 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/calculator.py @@ -0,0 +1,205 @@ +import logging +from pathlib import Path +from typing import Any + +from src.algorithms.dialogue import Dialogue +from src.algorithms.summarize_algorithms.core.models import ( + BaseBlock, + CodeBlock, + DialogueState, + Session, + ToolCallBlock, +) +from src.benchmark.logger.base_logger import BaseLogger +from src.benchmark.models.dtos import BaseRecord, MemoryRecord, MetricState +from src.benchmark.tool_plan_benchmarking.evaluators.base_evaluator import BaseEvaluator +from src.benchmark.tool_plan_benchmarking.tools_and_schemas.parsed_jsons import PLAN_SCHEMA + + +class Calculator: + """Run algorithms, evaluate results, and persist/re-evaluate benchmark logs.""" + + _LAST_LAUNCH_TIME_WINDOW_SECONDS: int = 120 + + @staticmethod + def evaluate( + algorithms: list[Dialogue], + evaluator_functions: list[BaseEvaluator], + sessions: list[Session], + prompt: str, + reference: list[BaseBlock], + logger: BaseLogger, + subdirectory: Path, + tools: list[dict[str, Any]] | None = None, + iteration: int | None = None, + ) -> list[BaseRecord]: + """Run and evaluate an algorithm, then save results via `logger`.""" + system_logger = logging.getLogger() + metrics: list[BaseRecord] = [] + + for algorithm in algorithms: + system_logger.info(f"Calculating {algorithm.system_name}") + state: DialogueState = algorithm.process_dialogue(sessions, prompt, PLAN_SCHEMA, tools) + + algorithm_metrics: list[MetricState] = Calculator.__evaluate_result( + evaluator_functions, + sessions, + prompt, + state, + reference, + ) + + record: BaseRecord = logger.log_iteration( + algorithm.system_name, + prompt, + iteration or 1, + sessions, + state, + Path(algorithm.system_name) / subdirectory, + algorithm_metrics, + ) + + metrics.append(record) + + return metrics + + @staticmethod + def evaluate_by_logs( + algorithms: list[Dialogue], + evaluator_functions: list[BaseEvaluator], + reference: list[BaseBlock], + logger: BaseLogger, + logs_path: Path | str, + subdirectory: Path, + iteration: int | None = None, + ) -> list[BaseRecord]: + """Append newly added metrics to existing log JSONs. + + Reads logs under `//`. + Updates files in-place by merging new metric values into the existing `metric` list. 
+ """ + system_logger = logging.getLogger() + updated_records: list[BaseRecord] = [] + + for algorithm in algorithms: + old_logs: list[BaseRecord] = logger.fetch_logs( + system_names=[alg.system_name for alg in algorithms], + subdirectory=subdirectory, + ) + + system_logger.info(f"Calculating {algorithm.system_name} by logs") + + for log in old_logs: + state = DialogueState( + dialogue_sessions=[], + prepared_messages=[], + code_memory_storage=None, + tool_memory_storage=None, + query=log.query, + _response=log.response, + ) + + sessions: list[Session] = Calculator._deserialize_sessions(log.sessions) + + algorithm_metrics = Calculator.__evaluate_result( + evaluator_functions=evaluator_functions, + sessions=sessions, + prompt=log.query, + state=state, + reference=reference, + ) + + # Keep `metric` as `MetricState` objects (do not mix in dicts). + if log.metric is None: + log.metric = [] + + existing_metric_names = set() + for m in log.metric: + # Backward compatibility: tolerate dicts if they ever appear. + if isinstance(m, dict): + name = m.get("metric_name") + if name is not None: + existing_metric_names.add(name) + else: + existing_metric_names.add(m.metric_name) + + for metric in algorithm_metrics: + if metric.metric_name in existing_metric_names: + log.metric.pop(log.metric.index(metric)) + log.metric.append(metric) + existing_metric_names.add(metric.metric_name) + + logger.save_log_dict( + Path(logs_path) / Path(algorithm.system_name) / subdirectory, + iteration or 1, + log.to_dict(), + algorithm.system_name + ) + + updated_records.append(log) + + return updated_records + + @staticmethod + def _record_from_dict(record_dict: dict[str, Any]) -> BaseRecord: + if record_dict.get("memory") is None: + return BaseRecord.from_dict(record_dict) + return MemoryRecord.from_dict(record_dict) + + @staticmethod + def _deserialize_sessions(raw_sessions: Any) -> list[Session]: + if not isinstance(raw_sessions, list): + return [] + + sessions: list[Session] = [] + for raw in raw_sessions: + if not isinstance(raw, dict): + continue + sessions.append(Calculator._deserialize_session(raw)) + + return sessions + + @staticmethod + def _deserialize_session(raw_session: dict[str, Any]) -> Session: + messages: list[BaseBlock] = [] + + for msg in raw_session.get("messages", []): + if not isinstance(msg, dict): + continue + + msg_type = msg.get("type") + if msg_type == "text": + messages.append(BaseBlock(role=str(msg.get("role", "")), content=str(msg.get("content", "")))) + elif msg_type == "code": + code = str(msg.get("code", "")) + messages.append(CodeBlock(role=str(msg.get("role", "")), content=code, code=code)) + elif msg_type == "tool_call": + messages.append( + ToolCallBlock( + role="TOOL_RESPONSE", + content="", + id=str(msg.get("id", "")), + name=str(msg.get("name", "")), + arguments=str(msg.get("arguments", "")), + response=str(msg.get("response", "")), + ) + ) + else: + messages.append(BaseBlock(role=str(msg.get("role", "")), content=str(msg.get("content", "")))) + + return Session(messages) + + @staticmethod + def __evaluate_result( + evaluator_functions: list[BaseEvaluator], + sessions: list[Session], + prompt: str, + state: DialogueState, + reference: list[BaseBlock], + ) -> list[MetricState]: + algorithm_metrics: list[MetricState] = [] + for evaluator_function in evaluator_functions: + metric = evaluator_function.evaluate(sessions, prompt, state, reference) + algorithm_metrics.append(metric) + + return algorithm_metrics diff --git a/src/benchmark/tool_plan_benchmarking/evaluators/__init__.py 
b/src/benchmark/tool_plan_benchmarking/evaluators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/benchmarking/tool_metrics/base_evaluator.py b/src/benchmark/tool_plan_benchmarking/evaluators/base_evaluator.py similarity index 80% rename from src/benchmarking/tool_metrics/base_evaluator.py rename to src/benchmark/tool_plan_benchmarking/evaluators/base_evaluator.py index b73583f..48b4af0 100644 --- a/src/benchmarking/tool_metrics/base_evaluator.py +++ b/src/benchmark/tool_plan_benchmarking/evaluators/base_evaluator.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod -from src.summarize_algorithms.core.models import ( +from src.algorithms.summarize_algorithms.core.models import ( BaseBlock, DialogueState, - MetricState, Session, ) +from src.benchmark.models.dtos import MetricState class BaseEvaluator(ABC): @@ -13,11 +13,14 @@ class BaseEvaluator(ABC): Base class for functions that evaluates llm's memory algorithms. """ + def __init__(self, mode: str | None = None): + self._mode = mode + @abstractmethod def evaluate( self, sessions: list[Session], - query: BaseBlock, + query: str, state: DialogueState, reference: list[BaseBlock] | None = None ) -> MetricState: diff --git a/src/benchmark/tool_plan_benchmarking/evaluators/f1_tool_evaluator.py b/src/benchmark/tool_plan_benchmarking/evaluators/f1_tool_evaluator.py new file mode 100644 index 0000000..4dfb5a6 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/evaluators/f1_tool_evaluator.py @@ -0,0 +1,169 @@ +import json + +from decimal import Decimal +from typing import Any + +from src.algorithms.summarize_algorithms.core.models import ( + BaseBlock, + DialogueState, + Session, + ToolCallBlock, +) +from src.benchmark.models.dtos import MetricState +from src.benchmark.models.enums import MetricType +from src.benchmark.tool_plan_benchmarking.evaluators.base_evaluator import BaseEvaluator +from src.utils.semantic_similarity import SemanticSimilarity + + +class F1ToolEvaluator(BaseEvaluator): + """ + Computes F1 between tools predicted by the model and tools used in a reference trace. + + The model is expected to return a structured response with a `plan_steps` list where tool calls are represented + as entries with `kind == "tool_call"`. + + Modes: + - default: compares only tool names + - "strict": compares tool names + exact JSON arguments + """ + + def evaluate( + self, + sessions: list[Session], + query: str, + state: DialogueState, + reference: list[BaseBlock] | None = None, + ) -> MetricState: + """ + Compute the F1 score for tool selection against a reference trace. + + :param sessions: previous sessions (unused here, but part of the evaluator interface). + :param query: the evaluated user query. + :param state: algorithm output state containing the model response. + :param reference: reference blocks containing expected tool calls. + :return: MetricState: metric name and computed value. 
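+
+        Example (default mode): a response whose `plan_steps` is
+        `[{"kind": "tool_call", "name": "search", ...}, {"kind": "<anything else>", ...}]`
+        yields `predicted_tools == {"search"}`; only `kind == "tool_call"` entries are counted.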
+ """ + if reference is None: + raise ValueError("Reference is required for F1 Tool evaluation.") + + if isinstance(state.response, str): + raise ValueError("State response must be a structured object (dict), not a string.") + + reference_tools: set[str] = { + tool.name + for tool in reference + if isinstance(tool, ToolCallBlock) + } + + plan_steps = state.response.get("plan_steps", []) + + if self._mode == "strict": + predicted_tools = F1ToolEvaluator._get_strict_matches(plan_steps, reference) + metric_type = MetricType.F1_TOOL_STRICT + elif self._mode == "arguments_similarity": + predicted_tools = F1ToolEvaluator._get_strict_matches(plan_steps, reference) + metric_type = MetricType.F1_TOOL_ARGUMENTS_SIMILARITY + else: + predicted_tools = F1ToolEvaluator._get_simple_matches(plan_steps) + metric_type = MetricType.F1_TOOL + + true_positives = len(predicted_tools.intersection(reference_tools)) + false_positives = len(predicted_tools.difference(reference_tools)) + false_negatives = len(reference_tools.difference(predicted_tools)) + + f1_score = F1ToolEvaluator._calculate_f1(true_positives, false_positives, false_negatives) + + return MetricState( + metric_name=metric_type, + metric_value=f1_score + ) + + @staticmethod + def _get_simple_matches(plan_steps: list[dict[str, Any]]) -> set[str]: + return { + step.get("name", "") + for step in plan_steps + if step.get("kind") == "tool_call" + } + + @staticmethod + def _get_strict_matches( + plan_steps: list[dict[str, Any]], + reference: list[BaseBlock] + ) -> set[str]: + matches = set() + + ref_tool_blocks = [r for r in reference if isinstance(r, ToolCallBlock)] + + for step in plan_steps: + if step.get("kind") != "tool_call": + continue + + step_name = step.get("name", "") + step_args = step.get("args", {}) + + is_match = any( + r.name.lower() == step_name and + F1ToolEvaluator._compare_arguments(step_args, json.loads(r.arguments)) + for r in ref_tool_blocks + ) + + if is_match: + matches.add(f"{step_name}|{step_args}") + + return matches + + @staticmethod + def _get_args_similarity_matches( + plan_steps: list[dict[str, Any]], + reference: list[BaseBlock] + ) -> set[str]: + arguments_similarity_threshold: float = 0.7 + + similarity = SemanticSimilarity() #TODO refactor it + matches: set[str] = set() + + for step in plan_steps: + if step.get("kind") != "tool_call": + continue + + step_name = str(step.get("name", "")).strip() + step_args = step.get("args", {}) + if not step_name: + continue + + for block in reference: + if isinstance(block, ToolCallBlock): + if block.name.lower() == step_name.lower(): + score = similarity.compare_json( + json.loads(block.arguments), + step_args + ) + if score >= arguments_similarity_threshold: + matches.add(f"{step_name}|{step_args}") + + return matches + + @staticmethod + def _calculate_f1(tp: int, fp: int, fn: int) -> Decimal: + """Compute F1 as an exact `Decimal`. + + Using Decimal avoids accumulating float rounding errors in downstream pipelines and is consistent with other + benchmark DTOs that already allow `Decimal` metric values. 
+ """ + zero = Decimal("0") + if tp == 0: + return zero + + tp_d = Decimal(tp) + precision = tp_d / Decimal(tp + fp) + recall = tp_d / Decimal(tp + fn) + + if precision + recall == zero: + return zero + + return Decimal(2) * (precision * recall) / (precision + recall) + + @staticmethod + def _compare_arguments(args1: dict[str, Any], args2: dict[str, Any]) -> bool: + return args1 == args2 diff --git a/src/benchmark/tool_plan_benchmarking/evaluators/llm_as_a_judge/__init__.py b/src/benchmark/tool_plan_benchmarking/evaluators/llm_as_a_judge/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/benchmark/tool_plan_benchmarking/evaluators/llm_as_a_judge_base_evaluator.py b/src/benchmark/tool_plan_benchmarking/evaluators/llm_as_a_judge_base_evaluator.py new file mode 100644 index 0000000..853fd2f --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/evaluators/llm_as_a_judge_base_evaluator.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import os +from abc import abstractmethod +from typing import Any + +from dotenv import load_dotenv +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI +from pydantic import BaseModel, SecretStr + +from src.algorithms.summarize_algorithms.core.models import BaseBlock, DialogueState, OpenAIModels, Session +from src.benchmark.models.dtos import MetricState +from src.benchmark.tool_plan_benchmarking.evaluators.base_evaluator import BaseEvaluator +from src.utils.system_prompt_builder import MemorySections, SystemPromptBuilder + + +class LLMAsAJudgeBaseEvaluator(BaseEvaluator): + """Base class for evaluators that delegate metric computation to an LLM "judge".""" + + def __init__( + self, + mode: str | None = None, + llm: BaseChatModel | None = None, + ) -> None: + super().__init__(mode=mode) + + load_dotenv() + + # Allow passing a fake/mock LLM without requiring OPENAI_API_KEY. 
+ if llm is not None: + self.llm = llm + else: + api_key: str | None = os.getenv("OPENAI_API_KEY") + if api_key is None: + raise ValueError("OPENAI_API_KEY environment variable is not loaded") + + self.llm = ChatOpenAI( + model=OpenAIModels.GPT_4_O_MINI.value, + api_key=SecretStr(api_key), + ) + + self._prompt_builder = SystemPromptBuilder() + self._system_message = SystemMessage(content=self._build_system_prompt()) + + def _build_system_prompt(self) -> str: + """Build a unified system prompt for the judge.""" + return self._prompt_builder.build( + schema=None, + tools=None, + memory=MemorySections(), + memory_mode="baseline", + examples=self._get_judge_examples(), + ) + + def _get_judge_examples(self) -> str: + return "" + + @abstractmethod + def _build_single_user_prompt(self, params: dict[str, Any]) -> str: + """Render the HumanMessage content for a single-option evaluation.""" + + @abstractmethod + def _build_pairwise_user_prompt(self, params: dict[str, Any]) -> str: + """Render the HumanMessage content for a pairwise evaluation.""" + + @abstractmethod + def _get_single_result_model(self) -> type[BaseModel]: + """Structured output model for single-option evaluation.""" + + @abstractmethod + def _get_pairwise_result_model(self) -> type[BaseModel]: + """Structured output model for pairwise evaluation.""" + + def _build_messages(self, user_prompt: str) -> list[BaseMessage]: + return [self._system_message, HumanMessage(content=user_prompt)] + + def _invoke_single(self, params: dict[str, Any]) -> BaseModel: + model = self._get_single_result_model() + chain = self.llm.with_structured_output(model) + return self._safe_invoke(chain, self._build_messages(self._build_single_user_prompt(params))) + + def _invoke_pairwise(self, params: dict[str, Any]) -> BaseModel: + model = self._get_pairwise_result_model() + chain = self.llm.with_structured_output(model) + return self._safe_invoke(chain, self._build_messages(self._build_pairwise_user_prompt(params))) + + @staticmethod + def _safe_invoke(chain: Any, messages: list[BaseMessage]) -> Any: + try: + return chain.invoke(messages) + except Exception as e: + raise ConnectionError(f"API request failed: {e}") from e + + @abstractmethod + def evaluate( + self, + sessions: list[Session], + query: str, + state: DialogueState, + reference: list[BaseBlock] | None = None, + ) -> MetricState: + ... diff --git a/src/benchmark/tool_plan_benchmarking/graphs/__init__.py b/src/benchmark/tool_plan_benchmarking/graphs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/benchmark/tool_plan_benchmarking/graphs/box_plot.py b/src/benchmark/tool_plan_benchmarking/graphs/box_plot.py new file mode 100644 index 0000000..6a24590 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/graphs/box_plot.py @@ -0,0 +1,54 @@ +import seaborn as sns + +from matplotlib import pyplot as plt + +from src.benchmark.tool_plan_benchmarking.graphs.graph_builder import GraphBuilder +from src.benchmark.tool_plan_benchmarking.statistics.dtos import StatisticsDto + + +class BoxPlot(GraphBuilder): + """ + Box-plot visualization of metric distributions per algorithm. + + Useful for comparing variance and outliers across different dialogue systems/baselines. 
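+
+    Illustrative call (assuming a populated `StatisticsDto`):
+
+        BoxPlot.build(statistics, "f1_tool_box_plot.png", title="F1_TOOL")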
+ """ + + @staticmethod + def build( + statistics: StatisticsDto, + path_to_save: str, + title: str = "", + ) -> None: + df = BoxPlot._runs_to_dataframe(statistics) + + sns.set_theme(style="whitegrid") + + plt.figure(figsize=(10, 6)) + ax = sns.boxplot( + data=df, + x="algorithm", + y="value", + width=0.6, + showfliers=True, + ) + + sns.stripplot( + data=df, + x="algorithm", + y="value", + color="black", + size=3, + alpha=0.6, + jitter=0.15, + ax=ax, + ) + + ax.set_title("Распределение значений " + title) + ax.set_xlabel("Алгоритм") + ax.set_ylabel(df["metric"].iloc[0]) # имя метрики из dto + + plt.xticks(rotation=20) + plt.tight_layout() + BoxPlot._save_figure(path_to_save) + plt.close() + diff --git a/src/benchmark/tool_plan_benchmarking/graphs/general_trends.py b/src/benchmark/tool_plan_benchmarking/graphs/general_trends.py new file mode 100644 index 0000000..c0e1b24 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/graphs/general_trends.py @@ -0,0 +1,64 @@ +import logging + +import seaborn as sns + +from matplotlib import pyplot as plt + +from src.benchmark.tool_plan_benchmarking.graphs.graph_builder import GraphBuilder +from src.benchmark.tool_plan_benchmarking.statistics.dtos import StatisticsDto + + +class GeneralTrends(GraphBuilder): + """ + Line-plot of metric trend as the number of sessions grows. + + For each algorithm, computes the mean score per `sessions` bucket and draws a curve. + """ + + @staticmethod + def build( + statistics: StatisticsDto, + path_to_save: str, + title: str = "", + ) -> None: + sns.set_theme(style="whitegrid") + + df = GeneralTrends._runs_to_dataframe(statistics) + + if df.empty: + logging.getLogger(__name__).warning( + "No runs found for graph '%s' (empty statistics). Skipping plot building.", + GeneralTrends.__name__, + ) + return + + grouped = df.groupby(["algorithm", "sessions"])["value"] + summary = grouped.mean().reset_index(name="mean") + + if summary.empty: + logging.getLogger(__name__).warning( + "No aggregated points for graph '%s'. Skipping plot building.", + GeneralTrends.__name__, + ) + return + + fig, ax = plt.subplots(figsize=(10, 6)) + + for algo, sub_df in summary.groupby("algorithm"): + sub = sub_df.sort_values("sessions") + + x = sub["sessions"].to_numpy() + y = sub["mean"].to_numpy().astype(float) + + ax.plot(x, y, marker="o", label=str(algo)) + + ax.set_xlabel("Число сессий") + ax.set_ylabel("F1_TOOL") + ax.set_title("Тенденции качества по числу сессий " + title) + ax.grid(True, which="major", axis="both", alpha=0.3) + ax.legend(title="Алгоритм") + + fig.tight_layout() + + GeneralTrends._save_figure(path_to_save) + plt.close(fig) diff --git a/src/benchmark/tool_plan_benchmarking/graphs/graph_builder.py b/src/benchmark/tool_plan_benchmarking/graphs/graph_builder.py new file mode 100644 index 0000000..303937d --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/graphs/graph_builder.py @@ -0,0 +1,56 @@ +from abc import ABC, abstractmethod + +import pandas as pd + +from matplotlib import pyplot as plt + +from src.benchmark.tool_plan_benchmarking.statistics.dtos import ( + AlgorithmStatistics, + StatisticsDto, +) + + +class GraphBuilder(ABC): + """ + Base class for graph/plot builders used in tool-metrics benchmark. + + Implementations take a `StatisticsDto` (a list of runs) and render a figure to disk. Helpers in this base class + convert runs into a Pandas `DataFrame` and save the active Matplotlib figure. 
+ """ + + @staticmethod + @abstractmethod + def build(statistics: StatisticsDto, path_to_save: str, title: str = "") -> None: + """ + Render a graph for the provided statistics and save it to `path_to_save`. + + :param statistics: aggregated run statistics. + :param path_to_save: output image path. + :param title: optional title suffix. + :return: None + """ + ... + + @staticmethod + def _runs_to_dataframe(stats: StatisticsDto) -> pd.DataFrame: + algorithms: list[AlgorithmStatistics] = stats.algorithms + rows = [ + { + "algorithm": alg.name, + "metric": alg.metric.value, + "sessions": run.sessions, + "value": run.value, + } + for alg in algorithms + for run in alg.runs + ] + + # Ensure expected columns exist even when there are no rows. + columns = ["algorithm", "metric", "sessions", "value"] + return pd.DataFrame(rows, columns=columns) + + @staticmethod + def _save_figure(path: str = "graph.png") -> None: + # Some callers pass an empty string; treat it as "use default". + safe_path = path or "graph.png" + plt.savefig(safe_path, dpi=300, bbox_inches="tight") diff --git a/src/benchmark/tool_plan_benchmarking/graphs/trends_with_quantiles.py b/src/benchmark/tool_plan_benchmarking/graphs/trends_with_quantiles.py new file mode 100644 index 0000000..349c0b3 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/graphs/trends_with_quantiles.py @@ -0,0 +1,55 @@ +import pandas as pd +import seaborn as sns + +from matplotlib import pyplot as plt + +from src.benchmark.tool_plan_benchmarking.graphs.graph_builder import GraphBuilder +from src.benchmark.tool_plan_benchmarking.statistics.dtos import StatisticsDto + + +class TrendsWithQuantiles(GraphBuilder): + """ + Trend plot with uncertainty bands. + + Plots mean metric value over sessions and shades the inter-quantile band (default: 25th–75th percentile). 
+ """ + + @staticmethod + def build(statistics: StatisticsDto, path_to_save: str, title: str = "") -> None: + sns.set_theme(style="whitegrid") + df = TrendsWithQuantiles._runs_to_dataframe(statistics) + summary = TrendsWithQuantiles.__summarize_for_bands(df) + + fig, ax = plt.subplots(figsize=(10, 6)) + + for algo, sub_df in summary.groupby("algorithm"): + sub = sub_df.sort_values("sessions") + + x = sub["sessions"].to_numpy() + y = sub["mean"].to_numpy().astype(float) + y_low = sub["q_low"].to_numpy().astype(float) + y_high = sub["q_high"].to_numpy().astype(float) + + ax.plot(x, y, marker="o", label=str(algo)) + ax.fill_between(x, y_low, y_high, alpha=0.2) + + ax.set_xlabel("Число сессий") + + metric_label = str(df["metric"].iloc[0]) if "metric" in df.columns and not df.empty else "Metric" + ax.set_ylabel(metric_label) + + ax.set_title("Тенденции качества " + title) + ax.legend() + fig.tight_layout() + TrendsWithQuantiles._save_figure(path_to_save) + plt.close(fig) + + @staticmethod + def __summarize_for_bands(df: pd.DataFrame, q_low: float = 0.25, q_high: float = 0.75) -> pd.DataFrame: + grouped = df.groupby(["algorithm", "sessions"])["value"] + summary = grouped.agg( + mean="mean", + q_low=lambda x: x.quantile(q_low), + q_high=lambda x: x.quantile(q_high), + ).reset_index() + return summary diff --git a/src/benchmark/tool_plan_benchmarking/load_session.py b/src/benchmark/tool_plan_benchmarking/load_session.py new file mode 100644 index 0000000..7d5ceea --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/load_session.py @@ -0,0 +1,227 @@ +import json + +from pathlib import Path +from typing import Any + +from src.algorithms.summarize_algorithms.core.models import ( + BaseBlock, + Session, + ToolCallBlock, +) + + +class Loader: + """ + Utilities for loading tool schemas and dialogue sessions from exported JSON formats used in benchmark. + + The benchmark code supports multiple dataset formats ("data_type_1" and "data_type_2"). This class converts + them into the project’s internal `Session`/`BaseBlock` representation. + """ + + @staticmethod + def load_func_tools(path: Path | str) -> list[dict[str, Any]]: + """ + Load tool/function specifications from a JSON file. + + :param path: path to JSON file. + :return: list[dict[str, Any]]: tool specs. + """ + with open(path, encoding="utf-8") as f: + data: list[dict[str, Any]] = json.load(f) + + return data + + @staticmethod + def load_session_data_type_2(path: Path | str) -> Session: + """ + Load a dialogue session from the "data_type_2" JSON export format. + + :param path: path to session JSON. + :return: Session: parsed session in internal format. 
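The export schema is only visible through the keys the loader reads (`type`, `content`, `toolCalls`/`tool_calls`, `toolResponses`/`tool_responses`), so the sample below is synthetic. It sketches how a "data_type_2" message list is folded into simple role/content blocks; `SimpleBlock` is a stand-in for the project's `BaseBlock`:

```python
# Sketch of the "data_type_2" message shape the loader expects and how it is
# folded into role/content blocks. SimpleBlock is a stand-in for BaseBlock and
# the sample messages are invented, not taken from a real export.
from dataclasses import dataclass


@dataclass
class SimpleBlock:
    role: str
    content: str


raw = [
    {"type": "system", "content": "You are a coding assistant."},
    {"type": "user", "content": "List the project files."},
    {
        "type": "assistant",
        "content": "Calling list_dir...",
        "toolCalls": [
            {"id": "c1", "name": "list_dir", "arguments": '{"directory_path": ".", "depth": 1}'}
        ],
    },
    {
        "type": "tool_response",
        "toolResponses": [
            {"id": "c1", "response": {"result": "success", "content": "src/ tests/"}}
        ],
    },
]

blocks: list[SimpleBlock] = []
for message in raw:
    role = message["type"].upper()
    if role in ("USER", "SYSTEM", "ASSISTANT"):
        blocks.append(SimpleBlock(role=role, content=str(message["content"])))
    # tool_response entries are consumed together with the preceding assistant
    # message in the real loader (matched by call id), so they are skipped here.

for block in blocks:
    print(block.role, "->", block.content[:40])
```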
+ """ + result: list[BaseBlock] = [] + with open(path, encoding="utf-8") as f: + data = json.load(f) + + i = 0 + while i < len(data): + dict_block = data[i] + block_type = dict_block["type"] + + if block_type in ("user", "system", "USER", "SYSTEM"): + block = BaseBlock( + role=block_type.upper(), + content=dict_block["content"] + ) + result.append(block) + i += 1 + + elif block_type in ("assistant", "ASSISTANT"): + block = BaseBlock( + role=block_type.upper(), + content=dict_block["content"] + ) + result.append(block) + + tool_calls = dict_block.get("toolCalls") or dict_block.get("tool_calls", []) + if tool_calls: + i += 1 + blocks = Loader.__process_tool_calls_data_type_2( + tool_calls, + data[i].get("toolResponses") + or data[i].get("tool_responses", []) + ) + result.extend(blocks) + + i += 1 + + elif block_type in ("tool_response", "TOOL_RESPONSE", "tool", "TOOL"): + i += 1 + + else: + block = BaseBlock( + role=block_type.upper(), + content=str(dict_block.get("content", "")) + ) + result.append(block) + i += 1 + + return Session(result) + + @staticmethod + def load_session_data_type_1(path: Path | str) -> Session: + """ + Load a dialogue session from the "data_type_1" JSON export format. + + :param path: path to session JSON. + :return: Session: parsed session in internal format. + """ + with open(path, encoding="utf-8") as f: + raw = json.load(f) + + chat = raw.get("serializableChat", {}) + messages = chat.get("messages", []) + + result: list[BaseBlock] = [] + + for message in messages: + message_type: str = message.get("type", "") + + if message_type.endswith("SerializableMessage.UserMessage"): + prompt = message.get("prompt", "") + block = BaseBlock( + role="USER", + content=prompt, + ) + result.append(block) + + elif message_type.endswith("SerializableMessage.AssistantMessage"): + response = message.get("response", "") or "" + reasoning = message.get("reasoning", "") or "" + if reasoning.strip(): + content = f"{response}\n\n[reasoning]\n{reasoning}" + else: + content = response + + block = BaseBlock( + role="ASSISTANT", + content=content, + ) + result.append(block) + + elif message_type.endswith("SerializableMessage.ToolMessage"): + block = Loader.__process_tool_calls_data_type_1(message) + result.append(block) + + else: + block = BaseBlock( + role="UNKNOWN", + content=str(message), + ) + result.append(block) + + return Session(result) + + @staticmethod + def __process_tool_calls_data_type_1(message: dict[str, Any]) -> ToolCallBlock: + tool_call = message.get("toolCall") or {} + tool_resp = message.get("toolResponse") or {} + tool_id = tool_call.get("id", "") + name = tool_call.get("name", "") + arguments = tool_call.get("arguments", "") + response_result = tool_resp.get("result", "") + if response_result == "failure": + content = tool_resp.get("failure", "") or "" + else: + content = tool_resp.get("content", "") or "" + block = ToolCallBlock( + role="TOOL_RESPONSE", + content=content, + id=tool_id, + name=name, + arguments=arguments, + response=response_result, + ) + return block + + @staticmethod + def __process_tool_calls_data_type_2( + tool_calls: list[dict[str, Any]], + tool_responses: list[dict[str, dict[str, Any]]] + ) -> list[ToolCallBlock]: + blocks: list[ToolCallBlock] = [] + for call in tool_calls: + for tool_response in tool_responses: + if "response" in tool_response: + response: dict[str, Any] | None = tool_response.get("response") + elif "responseData" in tool_response: + response = json.loads(str(tool_response.get("responseData"))) + else: + continue + + if response 
is None: + continue + + if call["id"] == tool_response["id"]: + if response["result"] == "failure": + content = response["failure"] + else: + content = response["content"] + + block = ToolCallBlock( + role="TOOL_RESPONSE", + content=content, + id=call["id"], + name=call["name"], + arguments=call["arguments"], + response=response["result"] + ) + blocks.append(block) + return blocks + + +if __name__ == "__main__": + json_file_template: str = "*.json" + path_data_type_1: Path = Path("/Users/mikhailkharlamov/Documents/.../data_type_1") + for file in path_data_type_1.glob(json_file_template): + print(file) + session = Loader.load_session_data_type_1(file) + print(len(session), "- session") + tools = [] + for block in session: + if isinstance(block, ToolCallBlock): + tools.append(block) + print(len(tools), "- tools") + print(len(tools) / len(session)) + + path_data_type_2: Path = Path("/Users/mikhailkharlamov/Documents/.../data_type_2") + for file in path_data_type_2.glob(json_file_template): + print(file) + session = Loader.load_session_data_type_2(file) + print(len(session), "- session") + tools = [] + for block in session: + if isinstance(block, ToolCallBlock): + tools.append(block) + print(len(tools), "- tools") + print(len(tools) / len(session)) diff --git a/src/benchmark/tool_plan_benchmarking/run.py b/src/benchmark/tool_plan_benchmarking/run.py new file mode 100644 index 0000000..6aa5a8d --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/run.py @@ -0,0 +1,495 @@ +import argparse +import json +import logging +import os +from pathlib import Path + +import tiktoken +from jinja2 import Environment, FileSystemLoader +from load_dotenv import load_dotenv + +from src.algorithms.dialogue import Dialogue +from src.algorithms.simple_algorithms.dialog_short_tools import DialogueWithShortTools +from src.algorithms.simple_algorithms.dialog_with_weights import DialogueWithWeights +from src.algorithms.simple_algorithms.dialogue_baseline import DialogueBaseline +from src.algorithms.summarize_algorithms.core.models import ( + BaseBlock, + OpenAIModels, + Session, +) +from src.algorithms.summarize_algorithms.memory_bank.dialogue_system import ( + MemoryBankDialogueSystem, +) +from src.algorithms.summarize_algorithms.recsum.dialogue_system import ( + RecsumDialogueSystem, +) +from src.benchmark.logger.baseline_logger import BaselineLogger +from src.benchmark.logger.memory_logger import MemoryLogger +from src.benchmark.models.dtos import ( + MODEL_PRICES, + BaseRecord, + DividedSession, + MemoryRecord, + TokenInfo, +) +from src.benchmark.models.enums import AlgorithmName, MetricType, AlgorithmDirectory +from src.benchmark.tool_plan_benchmarking.evaluators.f1_tool_evaluator import ( + F1ToolEvaluator, +) +from src.benchmark.tool_plan_benchmarking.graphs.general_trends import GeneralTrends +from src.benchmark.tool_plan_benchmarking.graphs.graph_builder import GraphBuilder +from src.benchmark.tool_plan_benchmarking.load_session import Loader +from src.benchmark.tool_plan_benchmarking.statistics.dtos import ( + AlgorithmStatistics, + StatisticsDto, +) +from src.benchmark.tool_plan_benchmarking.statistics.statistics import Statistics +from src.utils.configure_logs import configure_logs + +load_dotenv() + +BASE_DATA_PATH = os.getenv("BASE_DATA_PATH", "") +LOGS_PATH = os.getenv("LOGS_PATH", Path(__file__).resolve().parent / "logs" / "memory") + +JSON_FILE_TEMPLATE: str = "*.json" + + +class Runner: + """ + Orchestrates tool-metrics benchmark runs. 
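Both loaders ultimately pair tool calls with their responses by `id`. A small stand-alone sketch of that matching step, with synthetic call/response records and a simplified output shape (the real loaders also handle a `responseData` JSON-string variant and per-format role handling):

```python
# Sketch of the id-based pairing of tool calls with tool responses that both
# session loaders perform. Inputs are synthetic.
import json
from typing import Any

tool_calls: list[dict[str, Any]] = [
    {"id": "c1", "name": "list_dir", "arguments": '{"directory_path": ".", "depth": 1}'},
    {"id": "c2", "name": "read_file", "arguments": '{"target_file": "run.py", "start_line": 1, "end_line": 50}'},
]
tool_responses: list[dict[str, Any]] = [
    {"id": "c2", "response": {"result": "success", "content": "import argparse ..."}},
    {"id": "c1", "response": {"result": "failure", "failure": "permission denied"}},
]

responses_by_id = {resp["id"]: resp["response"] for resp in tool_responses}

paired: list[dict[str, Any]] = []
for call in tool_calls:
    response = responses_by_id.get(call["id"])
    if response is None:
        continue  # a call without a recorded response is dropped
    content = response["failure"] if response["result"] == "failure" else response["content"]
    paired.append(
        {
            "id": call["id"],
            "name": call["name"],
            "arguments": json.loads(call["arguments"]),
            "result": response["result"],
            "content": content,
        }
    )

print(json.dumps(paired, indent=2))
```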
+ + Responsibilities: + - loads past sessions + a "gold" session from `BASE_DATA_PATH` + - runs a selected dialogue system/baseline + - evaluates outputs via `BaseEvaluator` implementations (e.g. `F1ToolEvaluator`) + - optionally builds graphs from saved logs + """ + + def __init__(self, templates_dir: str = "prompts") -> None: + self._logger = logging.getLogger() + self._env = Environment( + loader=FileSystemLoader(templates_dir), + autoescape=True, + trim_blocks=True + ) + + self._baseline_logger = BaselineLogger() + self._memory_logger = MemoryLogger() + + def run(self, name: str) -> None: + """ + Run a full benchmark sweep for a given algorithm name. + + Loads input sessions from `BASE_DATA_PATH`, selects the algorithm/baseline, evaluates it with tool metrics, + and prints aggregated statistics. + + :param name: algorithm name (see `AlgorithmName`). + :return: None + """ + algorithm: Dialogue = Runner.__init_algorithm(AlgorithmName(name)) + + self._logger.info("Start parsing session") + past_interactions: list[Session] = [] + + path_data_type_1: Path = Path(BASE_DATA_PATH) / "data_type_1" + for file in path_data_type_1.glob(JSON_FILE_TEMPLATE): + past_interactions.append( + Loader.load_session_data_type_1(file) + ) + + path_data_type_2: Path = Path(BASE_DATA_PATH) / "data_type_2" + for file in path_data_type_2.glob(JSON_FILE_TEMPLATE): + past_interactions.append( + Loader.load_session_data_type_2(file) + ) + + gold_session: Session = Loader.load_session_data_type_1(Path(BASE_DATA_PATH) / "gold_session.json") + + divided_session: DividedSession = self.__divide_session(gold_session) + reference, session = divided_session.reference, divided_session.past_interactions + + # The algorithms expect `system_prompt` argument to contain the latest user request. 
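The benchmark splits the gold session at the last non-empty `USER` message: everything up to and including that message becomes the context handed to the algorithm, and the tail becomes the reference trace to score against. A sketch of that split with a stand-in block type (the project's `__divide_session`, shown further down in this diff, implements the same idea over `BaseBlock`/`Session`):

```python
# Sketch of splitting a "gold" session into context (up to and including the
# last non-empty USER message) and the reference tail used for scoring.
# SimpleBlock is a stand-in for the project's BaseBlock.
from dataclasses import dataclass


@dataclass
class SimpleBlock:
    role: str
    content: str


def divide(messages: list[SimpleBlock]) -> tuple[list[SimpleBlock], list[SimpleBlock]]:
    """Return (context_including_query, reference_tail)."""
    for i in range(len(messages) - 1, -1, -1):
        if messages[i].role == "USER" and messages[i].content:
            return messages[: i + 1], messages[i + 1 :]
    raise ValueError("no non-empty USER message found")


gold = [
    SimpleBlock("USER", "Add a unit test for the loader."),
    SimpleBlock("ASSISTANT", "Sure, looking at the project first."),
    SimpleBlock("USER", "Also run the existing tests."),
    SimpleBlock("TOOL_RESPONSE", "run_tests -> 12 passed"),
    SimpleBlock("ASSISTANT", "All 12 tests pass; new test added."),
]

context, reference = divide(gold)
print([b.role for b in context])    # ['USER', 'ASSISTANT', 'USER']
print([b.role for b in reference])  # ['TOOL_RESPONSE', 'ASSISTANT']
```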
+ prompt = session[-1].content if len(session) > 0 else "" + + f1_tool_evaluator_strict = F1ToolEvaluator("strict") + f1_tool_evaluator_arguments_similarity = F1ToolEvaluator("arguments_similarity") + f1_tool_evaluator = F1ToolEvaluator() + + for count_of_sessions in [1, 3, 5, 7, 9, 11, 13, 15]: + subdirectory: Path = Path(str(count_of_sessions)) + + if name in ("full_baseline", "short_tools", "weights"): + self._logger.info("Start evaluating full baseline statistics") + statistics: StatisticsDto = Statistics.calculate( + 5, + [algorithm], + [ + f1_tool_evaluator, + f1_tool_evaluator_strict, + f1_tool_evaluator_arguments_similarity + ], + past_interactions, + count_of_sessions, + Session(session), + prompt, + reference, + self._baseline_logger, + subdirectory, + None, + True + ) + elif name == "last_baseline": + self._logger.info("Start evaluating last baseline statistics") + statistics = Statistics.calculate( + 5, + [algorithm], + [ + f1_tool_evaluator, + f1_tool_evaluator_strict, + f1_tool_evaluator_arguments_similarity + ], + [], + count_of_sessions, + Session(session), + prompt, + reference, + self._baseline_logger, + subdirectory, + None, + True + ) + else: + self._logger.info("Start evaluating memory statistics") + statistics = Statistics.calculate( + 5, + [algorithm], + [ + f1_tool_evaluator, + f1_tool_evaluator_strict, + f1_tool_evaluator_arguments_similarity + ], + past_interactions, + count_of_sessions, + Session(session), + prompt, + reference, + self._memory_logger, + subdirectory, + None, + True + ) + + Statistics.print_statistics( + statistics + ) + + def evaluate_by_logs( + self, + name: str, + logs_path: Path | str = LOGS_PATH, + iteration: int | None = None, + ) -> None: + """Append newly-added metrics to existing log JSONs and print updated statistics. + + This mode does *not* re-run the dialogue system. 
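`F1ToolEvaluator` itself is outside this excerpt, so the following is an assumption-labelled illustration only: a "nonstrict" F1 computed over predicted vs. reference tool names, counting multiplicity. It shows the shape of the metric family, not the project's implementation:

```python
# Illustration (assumed, not the project's F1ToolEvaluator): F1 over the
# multiset of tool names in a predicted plan vs. a reference trace.
from collections import Counter


def f1_over_tool_names(predicted: list[str], reference: list[str]) -> float:
    if not predicted and not reference:
        return 1.0
    if not predicted or not reference:
        return 0.0
    overlap = sum((Counter(predicted) & Counter(reference)).values())
    precision = overlap / len(predicted)
    recall = overlap / len(reference)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


predicted_plan = ["list_dir", "read_file", "edit_file"]
reference_trace = ["list_dir", "read_file", "run_tests"]
print(round(f1_over_tool_names(predicted_plan, reference_trace), 3))  # 0.667
```

A "strict" variant would presumably also require matching arguments, and an "arguments_similarity" variant would score partial argument overlap, but those definitions are not shown in this diff.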
It: + - loads `gold_session.json` from `BASE_DATA_PATH` to build the reference trace + - re-evaluates the latest saved logs under `logs_path` (in-place) + + CLI usage (see `__main__` below): + python -m src.benchmark.tool_plan_benchmarking.run --eval-by-logs [--logs-path PATH] [--iteration N] + """ + algorithm: Dialogue = Runner.__init_algorithm(AlgorithmName(name)) + + gold_session: Session = Loader.load_session_data_type_1(Path(BASE_DATA_PATH) / "gold_session.json") + divided_session: DividedSession = self.__divide_session(gold_session) + reference = divided_session.reference + + f1_tool_evaluator_strict = F1ToolEvaluator("strict") + f1_tool_evaluator_arguments_similarity = F1ToolEvaluator("arguments_similarity") + f1_tool_evaluator = F1ToolEvaluator("nonstrict") + + if name in ("full_baseline", "short_tools", "weights"): + logger = BaselineLogger() + else: + logger = MemoryLogger() + + for count_of_sessions in [1, 3, 5, 7, 9, 11, 13, 15]: + subdirectory: Path = Path(str(count_of_sessions)) + self._logger.info("Start evaluating by logs (fold=%s)", count_of_sessions) + + statistics = Statistics.calculate_with_new_metrics_by_logs( + algorithms=[algorithm], + evaluator_functions=[ + f1_tool_evaluator, + f1_tool_evaluator_strict, + f1_tool_evaluator_arguments_similarity + ], + reference=reference, + logger=logger, + logs_path=logs_path, + subdirectory=subdirectory, + iteration=iteration, + normalize=False, + ) + + Statistics.print_statistics(statistics) + + @staticmethod + def __init_algorithm(name: AlgorithmName) -> Dialogue: + if name.value == "base_recsum": + return RecsumDialogueSystem(embed_code=False, embed_tool=False, system_name="BaseRecsum") + elif name.value == "base_memory_bank": + return MemoryBankDialogueSystem(embed_code=False, embed_tool=False, system_name="BaseMemoryBank") + elif name.value == "rag_recsum": + return RecsumDialogueSystem(embed_code=True, embed_tool=True, system_name="RagRecsum") + elif name.value == "rag_memory_bank": + return MemoryBankDialogueSystem(embed_code=True, embed_tool=True, system_name="RagMemoryBank") + elif name.value == "full_baseline": + return DialogueBaseline("FullBaseline") + elif name.value == "short_tools": + return DialogueWithShortTools("ShortTools") + elif name.value == "weights": + return DialogueWithWeights("Weights") + else: + return DialogueBaseline("LastBaseline") + + @staticmethod + def get_statistics_by_directory_with_logs(path: Path | str) -> StatisticsDto: + records: list[BaseRecord] = [] + for alg_folder in [ + "BaseMemoryBank", + "BaseRecsum", + "FullBaseline", + "LastBaseline", + "RagMemoryBank", + "RagRecsum" + ]: + folder = Path(path) / alg_folder + for path in folder.glob(JSON_FILE_TEMPLATE): + print(path) + with path.open("r", encoding="utf-8") as f: + obj = json.load(f) + if obj.get("memory") is None: + record = BaseRecord.from_dict(obj) + else: + record = MemoryRecord.from_dict(obj) + records.append(record) + return Statistics.calculate_by_logs( + count_of_launches=10, + metrics=records + ) + + @staticmethod + def get_spent_tokens_count(logs: BaseRecord, model: OpenAIModels) -> TokenInfo: + prompt = str(logs.query) + response = str(logs.response) + input_tokens = Runner.__count_tokens(prompt) + output_tokens = Runner.__count_tokens(response) + input_price = input_tokens * MODEL_PRICES[model].input_per_million / 1_000_000 + output_price = output_tokens * MODEL_PRICES[model].output_per_million / 1_000_000 + return TokenInfo( + model=model, + price=MODEL_PRICES[model], + input_tokens=input_tokens, + 
output_tokens=output_tokens, + input_price=input_price, + output_price=output_price, + total_price=input_price + output_price + ) + + def __divide_session(self, session: Session) -> DividedSession: + past_interactions: list[BaseBlock] = [] + reference: list[BaseBlock] = [] + query: BaseBlock | None = None + is_query_found: bool = False + for i in range(len(session.messages) - 1, -1, -1): + if is_query_found: + past_interactions.append(session.messages[i]) + elif session.messages[i].role == "USER" and session.messages[i].content != "": + self._logger.info(f"User founded {i}") + self._logger.info(f"User message: {session.messages[i].content}") + query = session.messages[i] + is_query_found = True + else: + reference.append(session.messages[i]) + + assert query is not None, "User's query is not founded." + + reference = reference[::-1] + past_interactions = past_interactions[::-1] + past_interactions.append(query) + + return DividedSession( + reference=reference, + past_interactions=past_interactions, + ) + + @staticmethod + def tokens() -> None: + Statistics.print_statistics( + Runner.get_statistics_by_directory_with_logs(LOGS_PATH) + ) + + path = Path(LOGS_PATH) + for directory in [ + "BaseMemoryBank", + "BaseRecsum", + "FullBaseline", + "LastBaseline", + "RagMemoryBank", + "RagRecsum" + ]: + folder = path / directory + ps = [] + for p in folder.glob(JSON_FILE_TEMPLATE): + ps.append(p) + ps.sort() + + for i in range(len(ps)): + print(i) + p = ps[i] + with p.open("r", encoding="utf-8") as f: + d = json.load(f) + if d.get("memory") is None: + record = BaseRecord.from_dict(d) + else: + record = MemoryRecord.from_dict(d) + tokens_info = Runner.get_spent_tokens_count(record, OpenAIModels.GPT_4_O_MINI) + p_n = p.resolve().parent / f"{p.stem}_tokens.json" + with p_n.open("w", encoding="utf-8") as f: + json.dump(tokens_info.to_dict(encode_json=True), f, indent=4) + + @staticmethod + def build_graph( + graph_types: list[type[GraphBuilder]], + directories: list[Path | str], + normalize: bool = False, + ) -> None: + path = Path(LOGS_PATH) + if not path.exists(): + logging.getLogger(__name__).warning( + "LOGS_PATH does not exist: %s. Set env LOGS_PATH or place logs under the default path.", + path, + ) + return + + algs: list[AlgorithmStatistics] = [] + for fold in [1, 3, 5, 7, 9, 11, 13, 15]: + ps: list[Path] = [] + for directory in sorted({str(d) for d in directories}): + folder = path / directory / f"{fold}" + if not folder.exists(): + continue + ps.extend(folder.glob(JSON_FILE_TEMPLATE)) + + ps.sort() + + if not ps: + logging.getLogger(__name__).info( + "No log files found for fold=%s under %s. 
Skipping fold.", + fold, + path, + ) + continue + + r: list[BaseRecord] = [] + for p in ps: + with p.open("r", encoding="utf8") as f: + j = json.load(f) + if "memory" in j: + data: BaseRecord = MemoryRecord.from_dict(j) + else: + data = BaseRecord.from_dict(j) + r.append(data) + + stats = Statistics.calculate_by_logs(fold, r, normalize=normalize) + algs.extend(stats.algorithms) + Statistics.print_statistics(stats) + + for graph in graph_types: + f1_nonstrict = [alg for alg in algs if alg.metric == MetricType.F1_TOOL] + f1_arguments_similarity = [alg for alg in algs if alg.metric == MetricType.F1_TOOL_ARGUMENTS_SIMILARITY] + f1_strict = [alg for alg in algs if alg.metric == MetricType.F1_TOOL_STRICT] + + graphs_dir = path / "graphs" + graphs_dir.mkdir(parents=True, exist_ok=True) + + graph.build( + StatisticsDto(algorithms=f1_nonstrict), + str(graphs_dir / f"{graph.__name__}_nonstrict.png"), + "nonstrict", + ) + graph.build( + StatisticsDto(algorithms=f1_strict), + str(graphs_dir / f"{graph.__name__}_strict.png"), + "strict", + ) + graph.build( + StatisticsDto(algorithms=f1_arguments_similarity), + str(graphs_dir / f"{graph.__name__}_strict.png"), + "arguments_similarity", + ) + + def __prepare_system_prompt(self) -> str: + template = self._env.get_template("first_stage.j2") + rendered_prompt = template.render() + return rendered_prompt + + @staticmethod + def __count_tokens(text: str) -> int: + encoding = tiktoken.get_encoding("o200k_base") + tokens = encoding.encode(text) + return len(tokens) + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Tool-plan benchmarking runner") + parser.add_argument( + "name", + help=f"Algorithm name. Allowed: {[a.value for a in AlgorithmName]}", + ) + parser.add_argument( + "--eval-by-logs", + action="store_true", + help="Do not run algorithms; instead, append newly-added metrics by re-evaluating saved logs in-place.", + ) + parser.add_argument( + "--logs-path", + default=str(LOGS_PATH), + help="Root directory with logs (default: env LOGS_PATH or tool default).", + ) + parser.add_argument( + "--iteration", + type=int, + default=None, + help="If set, evaluates only the newest log file with this iteration value.", + ) + return parser + + +if __name__ == "__main__": + configure_logs(loglevel=logging.INFO) + + args = _build_arg_parser().parse_args() + + runner = Runner() + if args.eval_by_logs: + runner.evaluate_by_logs(args.name, logs_path=args.logs_path, iteration=args.iteration) + else: + runner.run(args.name) + + Runner.build_graph( + [GeneralTrends], + [ + AlgorithmDirectory.FULL_BASELINE.value, + AlgorithmDirectory.LAST_BASELINE.value, + AlgorithmDirectory.RAG_MEMORY_BANK.value, + AlgorithmDirectory.RAG_RECSUM.value, + AlgorithmDirectory.RAG_MEMORY_BANK.value, + AlgorithmDirectory.BASE_MEMORY_BANK.value, + AlgorithmDirectory.BASE_RECSUM.value, + AlgorithmDirectory.WEIGHTS.value, + AlgorithmDirectory.SHORT_TOOLS.value, + ], + normalize=False, + ) diff --git a/src/benchmark/tool_plan_benchmarking/statistics/__init__.py b/src/benchmark/tool_plan_benchmarking/statistics/__init__.py new file mode 100644 index 0000000..5efd7eb --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/statistics/__init__.py @@ -0,0 +1,54 @@ +"""Statistics calculation helpers for tool-plan benchmarking. + +IMPORTANT: +This directory name (`statistics`) shadows Python's stdlib module `statistics` when +`src/benchmark/tool_plan_benchmarking/run.py` is executed as a script (Python adds +that directory to `sys.path`). 
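Earlier in `run.py`, `get_spent_tokens_count` estimates spend by counting tokens with tiktoken's `o200k_base` encoding and multiplying by per-million-token rates. A standalone sketch of that estimate; the rates below are placeholders, not the project's `MODEL_PRICES`:

```python
# Sketch of the token/cost estimate done in Runner.get_spent_tokens_count:
# count tokens with tiktoken's o200k_base encoding and convert to a price via
# per-million-token rates. The rates below are placeholders only.
import tiktoken

INPUT_PER_MILLION = 0.15   # placeholder $ per 1M input tokens
OUTPUT_PER_MILLION = 0.60  # placeholder $ per 1M output tokens

encoding = tiktoken.get_encoding("o200k_base")

prompt = "Summarize the previous sessions and plan the next tool calls."
response = "Plan: s1 list_dir, s2 read_file, s3 final_answer."

input_tokens = len(encoding.encode(prompt))
output_tokens = len(encoding.encode(response))

input_price = input_tokens * INPUT_PER_MILLION / 1_000_000
output_price = output_tokens * OUTPUT_PER_MILLION / 1_000_000

print(input_tokens, output_tokens, round(input_price + output_price, 8))
```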
Some third-party libraries (e.g. seaborn) import +`statistics.NormalDist` from the stdlib. + +To avoid breaking those imports, we dynamically load the stdlib `statistics.py` +under an alternate name and re-export the expected symbols. +""" + +from __future__ import annotations + +import importlib.util +import sysconfig + +from pathlib import Path +from types import ModuleType + + +def _load_stdlib_statistics() -> ModuleType: + stdlib_dir = Path(sysconfig.get_paths()["stdlib"]) # e.g. .../lib/python3.12 + statistics_path = stdlib_dir / "statistics.py" + + spec = importlib.util.spec_from_file_location("_stdlib_statistics", statistics_path) + if spec is None or spec.loader is None: + raise ImportError(f"Failed to load stdlib statistics module from {statistics_path}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +_stdlib_statistics = _load_stdlib_statistics() + +# Re-export the subset that external libraries commonly import. +NormalDist = _stdlib_statistics.NormalDist +mean = _stdlib_statistics.mean +median = _stdlib_statistics.median +stdev = _stdlib_statistics.stdev +pstdev = _stdlib_statistics.pstdev +variance = _stdlib_statistics.variance +pvariance = _stdlib_statistics.pvariance + +__all__ = [ + "NormalDist", + "mean", + "median", + "stdev", + "pstdev", + "variance", + "pvariance", +] diff --git a/src/benchmark/tool_plan_benchmarking/statistics/aggregator.py b/src/benchmark/tool_plan_benchmarking/statistics/aggregator.py new file mode 100644 index 0000000..824b032 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/statistics/aggregator.py @@ -0,0 +1,64 @@ +import logging + +from collections import defaultdict +from collections.abc import Iterable +from math import fsum + +from src.benchmark.tool_plan_benchmarking.statistics.dtos import ( + AlgorithmRun, + AlgorithmStatistics, + MetricKey, + MetricObservation, + StatisticsDto, +) +from src.benchmark.tool_plan_benchmarking.statistics.normalizer import MetricNormalizer + + +class MetricStatisticsAggregator: + """Aggregates metric observations into `StatisticsDto` (mean/variance + per-run values).""" + + def __init__(self, *, normalize: bool = False, logger: logging.Logger | None = None) -> None: + self._normalize = normalize + self._logger = logger or logging.getLogger() + + def aggregate(self, observations: Iterable[MetricObservation]) -> StatisticsDto: + values_by_key: dict[MetricKey, list[float]] = defaultdict(list) + for observation in observations: + values_by_key[observation.key].append(observation.value) + + if self._normalize: + values_by_key = MetricNormalizer.normalize_by_metric_type(dict(values_by_key)) + + algorithm_stats: list[AlgorithmStatistics] = [] + self._logger.info("Getting statistics...") + + for key, values in values_by_key.items(): + self._logger.info("Getting %s %s statistics", key.algorithm, key.metric.value) + + n = len(values) + if n == 0: + continue + + mean = fsum(values) / n + variance = fsum((v - mean) ** 2 for v in values) / n + + algorithm_stats.append( + AlgorithmStatistics( + name=key.algorithm, + metric=key.metric, + count_of_launches=n, + mean=mean, + variance=variance, + runs=[ + AlgorithmRun( + algorithm=key.algorithm, + metric=key.metric, + value=value, + sessions=key.session_count, + ) + for value in values + ], + ) + ) + + return StatisticsDto(algorithms=algorithm_stats) diff --git a/src/benchmark/tool_plan_benchmarking/statistics/dtos.py b/src/benchmark/tool_plan_benchmarking/statistics/dtos.py new file mode 100644 index 
0000000..7173804 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/statistics/dtos.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass + +from src.benchmark.models.enums import AlgorithmName, MetricType + + +@dataclass(frozen=True, slots=True) +class MetricKey: + """Grouping key for a single metric series. + A series is identified by: + - algorithm name + - metric type + - number of sessions used in the run + """ + algorithm: str + metric: MetricType + session_count: int + + +@dataclass(frozen=True, slots=True) +class MetricObservation: + """A single observed metric value for a given `MetricKey`.""" + key: MetricKey + value: float + + +@dataclass +class AlgorithmRun: + algorithm: str + metric: MetricType + value: float + sessions: int + + +@dataclass +class AlgorithmStatistics: + name: str + metric: MetricType + count_of_launches: int + mean: float + variance: float + runs: list[AlgorithmRun] + # mode: int | float + + +@dataclass +class MetricValues: + algorithm: AlgorithmName + metric: MetricType + values: list[float] + + +@dataclass +class StatisticsDto: + algorithms: list[AlgorithmStatistics] diff --git a/src/benchmark/tool_plan_benchmarking/statistics/evaluation_runner.py b/src/benchmark/tool_plan_benchmarking/statistics/evaluation_runner.py new file mode 100644 index 0000000..9c33379 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/statistics/evaluation_runner.py @@ -0,0 +1,85 @@ +import hashlib +import logging +import random + +from pathlib import Path +from typing import Any + +from src.algorithms.dialogue import Dialogue +from src.algorithms.summarize_algorithms.core.models import BaseBlock, Session +from src.benchmark.logger.base_logger import BaseLogger +from src.benchmark.models.dtos import BaseRecord +from src.benchmark.tool_plan_benchmarking.calculator import Calculator +from src.benchmark.tool_plan_benchmarking.evaluators.base_evaluator import BaseEvaluator + + +class EvaluationLaunchRunner: + """Runs repeated evaluation launches and returns produced log records.""" + + _RUN_ID = "exp_3_02_2026_8_43_pm" + + def __init__(self, logger: logging.Logger | None = None) -> None: + self._logger = logger or logging.getLogger() + + def run( + self, + launch_count: int, + algorithms: list[Dialogue], + evaluators: list[BaseEvaluator], + past_sessions: list[Session], + session_count: int, + gold_session: Session, + prompt: str, + reference: list[BaseBlock], + results_logger: BaseLogger, + subdirectory: Path, + tools: list[dict[str, Any]] | None = None, + shuffle: bool = False, + ) -> list[BaseRecord]: + records: list[BaseRecord] = [] + for launch_index in range(launch_count): + prepared_sessions = self._prepare_sessions( + past_sessions=past_sessions, + session_count=session_count, + gold_session=gold_session, + shuffle=shuffle, + seed=self._make_seed(launch_count, launch_index), + ) + + self._logger.info("Starting evaluation launch %s", launch_index) + records.extend( + Calculator.evaluate( + algorithms, + evaluators, + prepared_sessions, + prompt, + reference, + results_logger, + subdirectory, + tools, + launch_index, + ) + ) + + return records + + @staticmethod + def _prepare_sessions( + past_sessions: list[Session], + session_count: int, + gold_session: Session, + shuffle: bool, + seed: int, + ) -> list[Session]: + sessions = past_sessions.copy() + if len(sessions) > 1 and shuffle: + random.Random(seed).shuffle(sessions) + + prepared_sessions = sessions[: max(session_count - 1, 0)] + prepared_sessions.append(gold_session) + return prepared_sessions + + @classmethod + 
def _make_seed(cls, total_launches: int, launch_index: int) -> int: + digest = hashlib.sha256(f"{cls._RUN_ID}:{total_launches}:{launch_index}".encode()).hexdigest() + return int(digest[:16], 16) diff --git a/src/benchmark/tool_plan_benchmarking/statistics/normalizer.py b/src/benchmark/tool_plan_benchmarking/statistics/normalizer.py new file mode 100644 index 0000000..30bf542 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/statistics/normalizer.py @@ -0,0 +1,37 @@ +from collections import defaultdict + +from src.benchmark.models.enums import MetricType +from src.benchmark.tool_plan_benchmarking.statistics.dtos import MetricKey + + +class MetricNormalizer: + """Normalizes metric series across algorithms. + + Normalization is done *per metric type* (e.g. F1_TOOL, F1_TOOL_STRICT) using global min/max across all + algorithms and session counts for that metric. + """ + + @staticmethod + def normalize_by_metric_type(values_by_key: dict[MetricKey, list[float]]) -> dict[MetricKey, list[float]]: + values_by_metric: dict[MetricType, list[float]] = defaultdict(list) + for key, values in values_by_key.items(): + values_by_metric[key.metric].extend(values) + + if not values_by_metric: + return values_by_key + + min_max_by_metric: dict[MetricType, tuple[float, float]] = {} + for metric, values in values_by_metric.items(): + if not values: + continue + min_max_by_metric[metric] = (min(values), max(values)) + + normalized: dict[MetricKey, list[float]] = {} + for key, values in values_by_key.items(): + global_min, global_max = min_max_by_metric.get(key.metric, (0.0, 0.0)) + if global_max == global_min: + normalized[key] = values + continue + normalized[key] = [(v - global_min) / (global_max - global_min) for v in values] + + return normalized diff --git a/src/benchmark/tool_plan_benchmarking/statistics/observations_collector.py b/src/benchmark/tool_plan_benchmarking/statistics/observations_collector.py new file mode 100644 index 0000000..1b5365a --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/statistics/observations_collector.py @@ -0,0 +1,33 @@ +from collections.abc import Iterable + +from src.benchmark.models.dtos import BaseRecord +from src.benchmark.tool_plan_benchmarking.statistics.dtos import ( + MetricKey, + MetricObservation, +) + + +class MetricObservationsCollector: + """Extracts metric observations from benchmark log records.""" + + @staticmethod + def collect(records: Iterable[BaseRecord]) -> list[MetricObservation]: + observations: list[MetricObservation] = [] + for record in records: + if record.metric is None: + continue + + session_count = len(record.sessions) + for metric_state in record.metric: + observations.append( + MetricObservation( + key=MetricKey( + algorithm=record.system, + metric=metric_state.metric_name, + session_count=session_count, + ), + value=float(metric_state.metric_value), + ) + ) + + return observations diff --git a/src/benchmark/tool_plan_benchmarking/statistics/printer.py b/src/benchmark/tool_plan_benchmarking/statistics/printer.py new file mode 100644 index 0000000..a3fa828 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/statistics/printer.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from src.benchmark.tool_plan_benchmarking.statistics.dtos import StatisticsDto + + +class StatisticsPrinter: + """Human-friendly console output for `StatisticsDto`.""" + + @staticmethod + def print(stats: StatisticsDto) -> None: + print("=== Statistics ===") + for alg_stat in stats.algorithms: + print(f"Algorithm: {alg_stat.name}") + print(f" 
Metric: {alg_stat.metric.value}") + print(f" Count of launches: {alg_stat.count_of_launches}") + print(f" Math. expectation: {alg_stat.mean:.4f}") + print(f" Variance: {alg_stat.variance:.4f}") + print() diff --git a/src/benchmark/tool_plan_benchmarking/statistics/statistics.py b/src/benchmark/tool_plan_benchmarking/statistics/statistics.py new file mode 100644 index 0000000..bf19764 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/statistics/statistics.py @@ -0,0 +1,107 @@ +import logging + +from pathlib import Path +from typing import Any + +from src.algorithms.dialogue import Dialogue +from src.algorithms.summarize_algorithms.core.models import BaseBlock, Session +from src.benchmark.logger.base_logger import BaseLogger +from src.benchmark.models.dtos import BaseRecord +from src.benchmark.tool_plan_benchmarking.calculator import Calculator +from src.benchmark.tool_plan_benchmarking.evaluators.base_evaluator import BaseEvaluator +from src.benchmark.tool_plan_benchmarking.statistics.aggregator import MetricStatisticsAggregator +from src.benchmark.tool_plan_benchmarking.statistics.dtos import StatisticsDto +from src.benchmark.tool_plan_benchmarking.statistics.evaluation_runner import EvaluationLaunchRunner +from src.benchmark.tool_plan_benchmarking.statistics.observations_collector import MetricObservationsCollector +from src.benchmark.tool_plan_benchmarking.statistics.printer import StatisticsPrinter + + +class Statistics: + """Facade used by existing call sites (`run.py`, etc.).""" + + @staticmethod + def calculate( + count_of_launches: int, + algorithms: list[Dialogue], + evaluator_functions: list[BaseEvaluator], + sessions: list[Session], + count_of_sessions: int, + gold_session: Session, + prompt: str, + reference: list[BaseBlock], + logger: BaseLogger, + subdirectory: Path, + tools: list[dict[str, Any]] | None = None, + shuffle: bool = False, + ) -> StatisticsDto: + system_logger = logging.getLogger() + + records = EvaluationLaunchRunner(system_logger).run( + launch_count=count_of_launches, + algorithms=algorithms, + evaluators=evaluator_functions, + past_sessions=sessions, + session_count=count_of_sessions, + gold_session=gold_session, + prompt=prompt, + reference=reference, + results_logger=logger, + subdirectory=subdirectory, + tools=tools, + shuffle=shuffle, + ) + + observations = MetricObservationsCollector.collect(records) + return MetricStatisticsAggregator(normalize=False, logger=system_logger).aggregate(observations) + + @staticmethod + def calculate_by_logs( + count_of_launches: int, + metrics: list[BaseRecord], + system_logger: logging.Logger | None = None, + normalize: bool = False, + ) -> StatisticsDto: + _ = count_of_launches # kept for backward-compatible signature + system_logger = system_logger or logging.getLogger() + + observations = MetricObservationsCollector.collect(metrics) + return MetricStatisticsAggregator(normalize=normalize, logger=system_logger).aggregate(observations) + + @staticmethod + def calculate_with_new_metrics_by_logs( + algorithms: list[Dialogue], + evaluator_functions: list[BaseEvaluator], + reference: list[BaseBlock], + logger: BaseLogger, + logs_path: Path | str, + subdirectory: Path, + iteration: int | None = None, + system_logger: logging.Logger | None = None, + normalize: bool = False, + ) -> StatisticsDto: + """Re-evaluate the last saved logs with new evaluators and return aggregated statistics. 
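Behind these facade methods, `MetricStatisticsAggregator` groups observations by (algorithm, metric, session count) and reports the mean and population variance of each series. A compact sketch with plain tuples as keys (the `MetricKey`/`StatisticsDto` types are omitted):

```python
# Sketch of the aggregation done by MetricStatisticsAggregator: group observed
# metric values by (algorithm, metric, sessions) and report mean and population
# variance. Plain tuples replace the DTO types; the values are invented.
from collections import defaultdict
from math import fsum

observations = [
    (("BaseRecsum", "F1_TOOL", 3), 0.50),
    (("BaseRecsum", "F1_TOOL", 3), 0.60),
    (("BaseRecsum", "F1_TOOL", 3), 0.55),
    (("FullBaseline", "F1_TOOL", 3), 0.40),
]

values_by_key: dict[tuple[str, str, int], list[float]] = defaultdict(list)
for key, value in observations:
    values_by_key[key].append(value)

for key, values in values_by_key.items():
    n = len(values)
    mean = fsum(values) / n
    variance = fsum((v - mean) ** 2 for v in values) / n  # population variance
    print(key, f"n={n} mean={mean:.4f} variance={variance:.6f}")
```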
+ + This is a convenience wrapper around: + - `Calculator.evaluate_by_logs()` to append new metrics into existing JSON log files + - `Statistics.calculate_by_logs()` to aggregate the resulting metric observations + + It updates files in-place under: `//`. + """ + system_logger = system_logger or logging.getLogger() + + records = Calculator.evaluate_by_logs( + algorithms=algorithms, + evaluator_functions=evaluator_functions, + reference=reference, + logger=logger, + logs_path=logs_path, + subdirectory=subdirectory, + iteration=iteration, + ) + + observations = MetricObservationsCollector.collect(records) + return MetricStatisticsAggregator(normalize=normalize, logger=system_logger).aggregate(observations) + + @staticmethod + def print_statistics(stats: StatisticsDto) -> None: + StatisticsPrinter.print(stats) diff --git a/src/benchmark/tool_plan_benchmarking/tools_and_schemas/__init__.py b/src/benchmark/tool_plan_benchmarking/tools_and_schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/benchmark/tool_plan_benchmarking/tools_and_schemas/output_schema.json b/src/benchmark/tool_plan_benchmarking/tools_and_schemas/output_schema.json new file mode 100644 index 0000000..f01fc31 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/tools_and_schemas/output_schema.json @@ -0,0 +1,68 @@ +{ + "title": "ActionPlan", + "description": "Schema for the structured action plan output from the simple_algorithms", + "type": "object", + "additionalProperties": false, + "required": ["version", "user_request", "memory_used", "assumptions", "plan_steps"], + "properties": { + "version": { "type": "string", "enum": ["1.0"] }, + "user_request": { "type": "string" }, + "context_summary": { "type": "string" }, + "memory_used": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": false, + "required": ["source", "key", "excerpt"], + "properties": { + "source": { "type": "string", "enum": ["session", "memory_bank", "tool_result", "note"] }, + "key": { "type": "string" }, + "excerpt": { "type": "string" } + } + } + }, + "assumptions": { + "type": "array", + "items": { "type": "string" } + }, + "risks": { + "type": "array", + "items": { "type": "string" } + }, + "plan_steps": { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "additionalProperties": false, + "required": ["id", "kind", "name", "description", "depends_on"], + "properties": { + "id": { "type": "string", "pattern": "^s[0-9]+$" }, + "kind": { "type": "string", "enum": ["tool_call", "final_answer"] }, + "name": { "type": "string" }, + "description": { "type": "string" }, + "depends_on": { + "type": "array", + "items": { "type": "string", "pattern": "^s[0-9]+$" } + }, + "condition": { "type": "string" }, + "expected": { "type": "string" }, + "args": { + "type": "object", + "description": "For kind=tool_call only. Omit for other kinds.", + "additionalProperties": true + }, + "answer_template": { + "type": "string", + "description": "For kind=final_answer only. Omit otherwise." 
+ }, + "collect_from": { + "type": "array", + "description": "For kind=final_answer: step ids whose outputs feed the answer.", + "items": { "type": "string", "pattern": "^s[0-9]+$" } + } + } + } + } + } +} \ No newline at end of file diff --git a/src/benchmark/tool_plan_benchmarking/tools_and_schemas/parsed_jsons.py b/src/benchmark/tool_plan_benchmarking/tools_and_schemas/parsed_jsons.py new file mode 100644 index 0000000..c47d3b8 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/tools_and_schemas/parsed_jsons.py @@ -0,0 +1,13 @@ +import json +import os + +from pathlib import Path +from typing import Any + +from dotenv import load_dotenv + +load_dotenv() +tools_and_schemas_path = Path(os.getenv("TOOLS_AND_SCHEMAS_PATH")) + +PLAN_SCHEMA: dict[str, Any] = json.loads((tools_and_schemas_path / "output_schema.json").read_text(encoding="utf-8")) +TOOLS: list[dict[str, Any]] = json.loads((tools_and_schemas_path / "tools.json").read_text(encoding="utf-8")) diff --git a/src/benchmark/tool_plan_benchmarking/tools_and_schemas/tools.json b/src/benchmark/tool_plan_benchmarking/tools_and_schemas/tools.json new file mode 100644 index 0000000..5ed5707 --- /dev/null +++ b/src/benchmark/tool_plan_benchmarking/tools_and_schemas/tools.json @@ -0,0 +1,2359 @@ +[ + { + "name": "list_dir", + "description": "List all files and directories in the specified project folder. \nThe listing is performed recursively up to the specified `depth`. \nThe overall result set is capped at 100 entries.\nIn practice, use a depth of 4 or less to avoid overly long outputs. \nUse this tool to explore the project structure.", + "arguments_schema": { + "type": "object", + "properties": { + "directory_path": { + "type": "string", + "description": "Path to the directory to list, relative to the project root." + }, + "depth": { + "type": "integer", + "description": "Recursion depth. 0 lists only the directory’s immediate children." + } + }, + "additionalProperties": false, + "required": [ + "directory_path", + "depth" + ] + } + }, + { + "name": "search_file_by_name", + "description": "\n Recursively search for project files whose filenames match a given glob pattern. \n The search is performed within `search_directory`, which must be inside the project (use `.` to search the entire project). \n Returns up to 50 matching file paths, relative to the project root. \n Note that this tool only searches files within the project directory, excluding libraries and external dependencies.\n ", + "arguments_schema": { + "type": "object", + "properties": { + "glob_pattern": { + "type": "string", + "description": "Glob pattern that matches the filename only (path not included)." + }, + "search_directory": { + "type": "string", + "description": "Path to the directory to search in, relative to the project root." + } + }, + "additionalProperties": false, + "required": [ + "glob_pattern", + "search_directory" + ] + } + }, + { + "name": "search_for_text", + "description": "Search for a literal text snippet using the IDE’s search engine. 
You can search within a single file or a project folder.\nTo search the entire project, set `target_path` to `.`.\nReturns up to 50 matches, each consisting of a file path (relative to the project root), the line content, and a 1-indexed line and column number.\nYou can choose whether the search is case-sensitive.\nMultiline text snippets are supported - pay attention to indentation.\nUse this tool to determine precise positions before invoking other tools that require them.\nCall this tool in one message as many times as you need: it is more cost-efficient than calling it in consecutive messages.", + "arguments_schema": { + "type": "object", + "properties": { + "target_path": { + "type": "string", + "description": "Path to the file or directory to search in, relative to the project root." + }, + "text_snippet": { + "type": "string", + "description": "Literal text to search for. Can be multiline." + }, + "is_case_sensitive": { + "type": "boolean", + "description": "Whether the search is case-sensitive." + } + }, + "additionalProperties": false, + "required": [ + "target_path", + "text_snippet", + "is_case_sensitive" + ] + } + }, + { + "name": "read_file", + "description": "Read the content of a file. The output includes the total number of lines in the file and the text of all lines between the 1-indexed `start_line` and `end_line` numbers (inclusive). \nThe number of lines returned cannot exceed 500. If you need more context, call the tool again to read another range.\nIt is much more cost-efficient to read a larger range once than reading two small ranges twice.", + "arguments_schema": { + "type": "object", + "properties": { + "target_file": { + "type": "string", + "description": "Path to the file to read, relative to the project root directory, or url to a library file." + }, + "start_line": { + "type": "integer", + "description": "The index (starting from 1) of the first line to read." + }, + "end_line": { + "type": "integer", + "description": "The index (starting from 1) of the last line to read. It must be no more than 500 lines away from `start_line`." + } + }, + "additionalProperties": false, + "required": [ + "target_file", + "start_line", + "end_line" + ] + } + }, + { + "name": "similar_search", + "description": "Find snippets of code from the codebase and available dependencies' sources most relevant to the search query using BM25.\nThis performs best when the search query is more precise and relating to the function or purpose of code.\nAfter the search query you must also add a list of relevant keywords. The list may be redundant, it should cover all possible concepts relevant to the search topic.\nThe keywords will be appended to the query, so don't use the same word multiple times and don't use words that already appear in the query.\nAny keyword must be a single word consisting of chars and digits only, without any delimiters.\nKeep the search query itself as a concise specific request, and use the keywords to suggest as many things as possible that can appear in relevant files.\nUse `read_file` with the same path or url and node name to view the full code contents for any item.\nDon't try to get more than 20 results, as this will not work well.\nReturns a list any element of which can be either a file path relative to the project root, or a url to a source file in dependencies.\nEach element of the result is intentionally surrounded with backsticks (`), so preserve backsticks (`) always, don't erase them. 
", + "arguments_schema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query" + }, + "keywords": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Keywords to append to the search query" + }, + "snippets_to_show": { + "type": "integer", + "description": "Number of snippets to return, must not exceed 20" + } + }, + "additionalProperties": false, + "required": [ + "query", + "keywords", + "snippets_to_show" + ] + } + }, + { + "name": "find_java_class_source", + "description": "Find Java/Kotlin (and other UAST-backed) class by name and return metadata only (no source text).\nAccepts either a short class name (e.g. \"List\") or a fully-qualified name (e.g. \"java.util.List\").\nSearches across the project and dependencies (attached sources and decompiled classes).\nIf multiple classes match, returns a paged list of matches without text. If exactly one match is found, returns detailed metadata only.\nTo read the file content, call the read_file tool with the returned file_url (VirtualFile URL).\nYou can filter by scope (project/libraries/all), by source kind (source/decompiled/any), and optionally restrict by a package prefix.\nThe page size is capped at 50.", + "arguments_schema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Class name to search for. Can be a short name (e.g. 'List') or a fully-qualified name (e.g. 'java.util.List')." + }, + "qualified_name": { + "type": [ + "string", + "null" + ], + "description": "Optional fully-qualified class name to select an exact class directly (e.g. 'com.example.Foo'). When provided, the search will target this class." + }, + "scope": { + "type": "string", + "enum": [ + "project", + "libraries", + "all" + ], + "description": "Search scope: 'project' to search only project content, 'libraries' to search only dependencies, 'all' to search everywhere." + }, + "kind_filter": { + "type": "string", + "enum": [ + "source", + "decompiled", + "any" + ], + "description": "Filter by source kind: 'source' (project or library sources only), 'decompiled' (library classes without sources), or 'any'." + }, + "package_prefix": { + "type": [ + "string", + "null" + ], + "description": "Optional package prefix to narrow results by qualified name (e.g. 'com.example'). Applied to qualified names only." + }, + "page": { + "type": "integer", + "description": "1-based page number for paging through matches. Defaults to 1." + }, + "page_size": { + "type": "integer", + "description": "Number of items per page (max 50). Defaults to 50." + } + }, + "additionalProperties": false, + "required": [ + "query", + "qualified_name", + "scope", + "kind_filter", + "package_prefix", + "page", + "page_size" + ] + } + }, + { + "name": "create_file", + "description": "Create a new file with the specified content. \nYou cannot use this tool to overwrite or edit existing files. \nIf you want to edit an existing file, use `edit_file` instead.\nThe tool returns a list of compilation errors in the created file.", + "arguments_schema": { + "type": "object", + "properties": { + "target_file": { + "type": "string", + "description": "Path to the new file, relative to the project root directory." + }, + "content": { + "type": "string", + "description": "The contents to write to the file. Can be an empty string to create an empty file." 
+ } + }, + "additionalProperties": false, + "required": [ + "target_file", + "content" + ] + } + }, + { + "name": "edit_file", + "description": "Make line-based edits to a project text file.\nEach edit replaces exact line sequences with new content.\nEdits cannot overlap.\nTo make just one edit you still have to pass a list containing 1 element.\nEmpty `old_text` in any of the edits is not allowed (except for empty files). If you want to insert text, you must still provide some lines above or below by appending them both to `old_text` and `new_text`.\nFor each edit, the tool checks that `old_text` appears exactly once in the file. Make sure to give enough context to specify the exact location of each edit with no ambiguity.\nCRITICAL: text is removed by **lines**, so even if `old_text` starts or ends in the middle of a line, the **whole** line will be removed. Your `new_text` should thus contain only full lines too.\nReturns a list of compilation errors after applying your edits.", + "arguments_schema": { + "type": "object", + "properties": { + "target_file": { + "type": "string", + "description": "Path to the file to edit, relative to the project root directory." + }, + "edits": { + "type": "array", + "items": { + "type": "object", + "properties": { + "old_text": { + "type": "string", + "description": "Text to search for - must match exactly." + }, + "new_text": { + "type": "string", + "description": "Text to replace with." + } + }, + "additionalProperties": false, + "required": [ + "old_text", + "new_text" + ] + }, + "description": "List of edits to apply to the file." + } + }, + "additionalProperties": false, + "required": [ + "target_file", + "edits" + ] + } + }, + { + "name": "run_ide_file_action", + "description": "Run a file-level IDE action. Available actions are:\n1. `Quick Fix` - apply all non-interactive quick fixes available in a source file. Use this on broken code before attempting to fix errors manually.\n2. `Reformat Code` - automatically format a source file.\n3. `Rename File` - rename a file and all its usages. You must specify a new file name using the `new_file_name` parameter. If the tool call is reported as successful, you **MUST NEVER** manually search the usages, because it would be just a waste of resources.\n", + "arguments_schema": { + "type": "object", + "properties": { + "action_name": { + "type": "string", + "description": "IDE action to perform." + }, + "target_file": { + "type": "string", + "description": "Path to the file on which to perform the action, relative to the project root." + }, + "new_file_name": { + "type": "string", + "description": "New file name for the `Rename File` action. Give an empty string for any other action." + } + }, + "additionalProperties": false, + "required": [ + "action_name", + "target_file", + "new_file_name" + ] + } + }, + { + "name": "read_terminal", + "description": "Read the output of the most recent terminal command.\nThe output includes all lines between the 1-indexed `start_line` and `end_line` numbers (inclusive). \nThe number of lines returned cannot exceed 500.\nAdditionally, the total output size cannot exceed 2000 characters.\nIf you need more context, call the tool again to read another range.", + "arguments_schema": { + "type": "object", + "properties": { + "start_line": { + "type": "integer", + "description": "The index (starting from 1) of the first line to read." + }, + "end_line": { + "type": "integer", + "description": "The index (starting from 1) of the last line to read. 
It must be no more than 500 lines away from `start_line`." + }, + "terminal_session_id": { + "type": "string", + "description": "Terminal session identifier" + } + }, + "additionalProperties": false, + "required": [ + "start_line", + "end_line", + "terminal_session_id" + ] + } + }, + { + "name": "run_command", + "description": "Run a shell command in the user’s terminal.\nYou may set `safe_to_run` to true only if the command is fully reversible, read-only, does not download or install software, and does not require elevated privileges.\nWhen `safe_to_run` is true, the command executes immediately without user confirmation.\nFor any command that could potentially be harmful or irreversible (e.g., `sudo`, `rm`, `mv`, `curl`, `wget`, package installers, etc.) always set `safe_to_run` to false, even if the user requests otherwise.\nNever prepend a standalone `cd`, use `working_directory` instead.\nIf the command is long-running or you do not need to see its output, set `is_background` to true.\nOutput is truncated to the first and last 25 lines.\nAdditionally, the tool response will not exceed 2000 characters.\nUser will not be able to interact with the command execution, so pass non-interaction flags (like `--yes`) when needed.\nIf a command relies on paging, append `| cat` to it.\nAvoid using this tool to run `ls` or `find` commands, prefer using `search_file_by_name` or `list_dir` tools.", + "arguments_schema": { + "type": "object", + "properties": { + "working_directory": { + "type": "string", + "description": "Path to the working directory from which to execute the command, relative to the project root directory." + }, + "command": { + "type": "string", + "description": "The exact shell command string to execute." + }, + "safe_to_run": { + "type": "boolean", + "description": "Whether it is completely safe to run the command. If true, the command will be executed without user confirmation." + }, + "is_background": { + "type": "boolean", + "description": "If true, the command will be run in the background while you continue interacting with the user." + } + }, + "additionalProperties": false, + "required": [ + "working_directory", + "command", + "safe_to_run", + "is_background" + ] + } + }, + { + "name": "get_static_ide_analysis", + "description": "List static analysis problems reported by the IDE.\nYou can list them for a single source file or recursively for all files within a project folder.\nThe output is a list of problems for each file (up to 100 in total), each with a description and a 1-indexed line and column number.\nThe tool queries the IDE’s in-editor inspector; it does not run the code.\nThe `severity` parameter lets you get either only errors or both errors and warnings.", + "arguments_schema": { + "type": "object", + "properties": { + "target_path": { + "type": "string", + "description": "Path to the source file or directory, relative to the project root." + }, + "severity": { + "type": "string", + "enum": [ + "ERROR", + "WARNING" + ], + "description": "Lowest severity to include." + } + }, + "additionalProperties": false, + "required": [ + "target_path", + "severity" + ] + } + }, + { + "name": "run_tests", + "description": "Run tests in the specified test class.\nExecute tests directly within the IDE and return results such as compacted exception stack traces.\nThe tool supports ability to collect test coverage for specified files. You can use `files_to_collect_coverage_for` parameter for it. 
Only collect test coverage when USER explicitly asks you to analyze or increase line/branch test coverage.\nThe tool can record a call tree dump which is very useful for debugging; use the `call_tree_dump_settings` parameter for it. Always include it by default, and set it to `null` only if a previous test run failed due to instrumentation.\nIf the test method contains compilation errors, the test is not executed.\nFor richer feedback and easier integration, it is recommended to use this tool instead of running tests via build system.", + "arguments_schema": { + "type": "object", + "properties": { + "test_class_name": { + "type": "string", + "description": "Fully qualified name of the test class. For example, `org.example.MyClassTest`." + }, + "test_method_names": { + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + }, + "description": "Names of test methods or `null` if all methods in the given test class must be run." + }, + "files_to_collect_coverage_for": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Paths to files to collect test coverage for, relative to the project root." + }, + "call_tree_dump_settings": { + "type": [ + "object", + "null" + ], + "description": "Settings used for recording call tree dumps.", + "properties": { + "base_package": { + "type": "string", + "description": "Package that should be traced, only call-sites in this package and its subpackages will be considered. The best value is almost always the base package of the organization (two dot-separated identifiers). For example, \"org.example\" or \"com.company\"." + }, + "method_of_interest_name": { + "type": "string", + "description": "Fully qualified name of the method whose call sub-tree should be expanded. The best choice is almost always either the method under test or test method. For example, `com.company.service.impl.MyServiceImpl.myMethod`, `org.springframework.test.web.servlet.MockMvc.perform`, or `org.organization.service.impl.MyServiceTest.myTest`." + }, + "to_string_implementation": { + "type": "string", + "enum": [ + "REAL_TO_STRING", + "POORLY_IMITATED_TO_STRING" + ], + "description": "Implementation of toString() to use in dumps values. Recommended value is `REAL_TO_STRING`. In very rare cases `REAL_TO_STRING` can alter test behavior; set to `POORLY_IMITATED_TO_STRING` only if `REAL_TO_STRING` appears to alter test behavior." + } + }, + "additionalProperties": false, + "required": [ + "base_package", + "method_of_interest_name", + "to_string_implementation" + ] + } + }, + "additionalProperties": false, + "required": [ + "test_class_name", + "test_method_names", + "files_to_collect_coverage_for", + "call_tree_dump_settings" + ] + } + }, + { + "name": "run_build_system", + "description": "Execute tasks using build system. This tool is specifically designed for running Gradle tasks or Maven goals.\nAvoid using `run_command` tool to run such commands.\nOutput is truncated to the first and last 35 lines.\nAdditionally, the tool response will not exceed 2000 characters.", + "arguments_schema": { + "type": "object", + "properties": { + "build_system_name": { + "type": "string", + "description": "Name of the build system. This parameter can only be set to \"Gradle\" or \"Maven\"." + }, + "working_directory": { + "type": "string", + "description": "Path to the working directory from which to execute the build system tasks, relative to the project root directory."
+ }, + "tasks": { + "type": "string", + "description": "A string containing tasks for build system (Gradle tasks or Maven goals), separated by a space" + } + }, + "additionalProperties": false, + "required": [ + "build_system_name", + "working_directory", + "tasks" + ] + } + }, + { + "name": "web_search", + "description": "Search the web for information using multiple specialized providers.\n\n# PROVIDERS:\n- StackOverflowSearchProvider: Searches for programming-related questions and answers\n- WikipediaSearchProvider: Searches for Wikipedia articles\n- GitHubSearchProvider: Searches repositories, issues, and users on GitHub\n\nWhen using GitHub provider you must set `search_target` parameter to specify what you're searching for:\n- REPOSITORIES - for code repositories/projects\n- USERS - for GitHub users and organizations \n- ISSUES - for issues and pull requests\n\nNote: For USERS search target, the provider automatically fetches detailed user information including bio, company, location, and follower counts for each result.\n\n# IMPORTANT GUIDELINES:\n- Use meaningful search keywords that describe what you're looking for\n- You must not repeat keywords in the same search query, it does not help you get more results\n- You can use filters in the query string to filter the results for StackOverflow and GitHub\n- Use `use_providers` parameter to optimize routing to relevant providers\n\n# USAGE EXAMPLES:\nExample 1 - Get a general overview of Kotlin coroutines (use all providers):\n{\n \"query\": \"kotlin coroutines overview\",\n \"topK\": 10,\n \"lang\": \"en\",\n \"sort\": \"RELEVANCE\"\n}\n\nExample 2 - Find popular Kotlin GitHub repositories:\n{\n \"query\": \"stars:>1000 language:kotlin\",\n \"search_target\": \"REPOSITORIES\",\n \"topK\": 10,\n \"sort\": \"RELEVANCE\",\n \"use_providers\": [\"GitHubSearchProvider\"]\n}\n\nExample 3 - Find users with many followers on GitHub:\n{\n \"query\": \"kotlin developer followers:>5000 type:user\",\n \"search_target\": \"USERS\",\n \"topK\": 5,\n \"use_providers\": [\"GitHubSearchProvider\"]\n}\n\nExample 4 - Find open bugs in a specific GitHub repository:\n{\n \"query\": \"repo:jetbrains/kotlin state:open label:bug\",\n \"search_target\": \"ISSUES\",\n \"topK\": 20,\n \"sort\": \"RECENCY\",\n \"use_providers\": [\"GitHubSearchProvider\"]\n}\n\nExample 5 - Search StackOverflow for Kotlin coroutines questions:\n{\n \"query\": \"coroutines channel [kotlin]\",\n \"topK\": 5,\n \"sort\": \"RECENCY\",\n \"use_providers\": [\"StackOverflowSearchProvider\"]\n}\n\n# COMMON MISTAKES TO AVOID:\nWRONG: query=\"GitHub users with more than 5000 followers\"\nRIGHT: query=\"developer followers:>5000\", search_target=\"USERS\", \"use_providers\": [\"GitHubSearchProvider\"]\n\nWRONG: query=\"kotlin kotlin language kotlin documentation\"\nRIGHT: query=\"kotlin language documentation\"\n\nWRONG: search_target not specified for GitHub provider\nRIGHT: Always specify search_target (REPOSITORIES, USERS, or ISSUES)\n\nWRONG: Using natural language descriptions in query\nRIGHT: Use keywords + specific provider queries", + "arguments_schema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query (1-1000 characters). Can include provider-specific filters." 
+ }, + "top_k": { + "type": "integer", + "description": "Maximum number of results to return (1-100)" + }, + "sort": { + "type": "string", + "enum": [ + "relevance", + "recency" + ], + "description": "Sorting method for the results, must be either 'relevance' (default) or 'recency'" + }, + "freshness_days": { + "type": [ + "integer", + "null" + ], + "description": "Boost results that are more recent, in days (1-3650 days, ~10 years max)" + }, + "lang": { + "type": "string", + "description": "Preferred language for the results (2-5 characters, format: 'en' or 'en-US')" + }, + "use_providers": { + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + }, + "description": "Use only these provider names (exact match) or null to use all available providers" + }, + "search_target": { + "type": [ + "string", + "null" + ], + "enum": [ + "REPOSITORIES", + "ISSUES", + "USERS" + ], + "description": "CRITICAL: Choose based on what you want to find - REPOSITORIES for repos, USERS for people/orgs, ISSUES for bugs/PRs, applied for GitHub provider only" + } + }, + "additionalProperties": false, + "required": [ + "query", + "top_k", + "sort", + "freshness_days", + "lang", + "use_providers", + "search_target" + ] + } + }, + { + "name": "web_fetch", + "description": "Fetch and extract content from web pages and PDFs.\n\nCapabilities:\n- Extracts clean, structured content from HTML pages using advanced content detection\n- Processes PDF documents with text extraction and structure preservation\n- Removes tracking parameters and normalizes URLs\n- Preserves headings, code blocks, tables, and document structure\n- Returns content as Markdown with rich metadata\n\nUse this tool when you need to read and process the actual content of web pages or PDFs.\nThe tool intelligently removes boilerplate content (navigation, ads, footers) and focuses on the main document content.", + "arguments_schema": { + "type": "object", + "properties": { + "urls": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of URLs to fetch and process, up to 10 URLs" + }, + "timeout": { + "type": [ + "integer", + "null" + ], + "description": "Optional timeout in seconds (default: 10)" + } + }, + "additionalProperties": false, + "required": [ + "urls", + "timeout" + ] + } + }, + { + "name": "ask_user_with_options", + "description": "Ask the user a question and provide up to 10 predefined answer options.\nYou can specify whether the user may select only one choice or multiple.\nImportant: this tool must be called alone; it cannot appear in the same assistant message together with any other tool calls.\nNote that when multiple selection is allowed, the user may also choose none of the options.", + "arguments_schema": { + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "Question to present to the user. Keep it short, 1-2 sentences." + }, + "options": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of answer options to present to the user. Keep each option short, no longer than 1 simple sentence." + }, + "is_multiple_choice": { + "type": "boolean", + "description": "If true, the user may select multiple choices; if false, the user must pick exactly one." + } + }, + "additionalProperties": false, + "required": [ + "question", + "options", + "is_multiple_choice" + ] + } + }, + { + "name": "add_comment_to_pending_review", + "description": "Add review comment to the requester's latest pending pull request review. 
A pending review needs to already exist to call this (check with the user if not sure).\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "body": { + "description": "The text of the review comment", + "type": "string" + }, + "line": { + "description": "The line of the blob in the pull request diff that the comment applies to. For multi-line comments, the last line of the range", + "type": "number" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "path": { + "description": "The relative path to the file that necessitates a comment", + "type": "string" + }, + "pullNumber": { + "description": "Pull request number", + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "side": { + "description": "The side of the diff to comment on. LEFT indicates the previous state, RIGHT indicates the new state", + "enum": [ + "LEFT", + "RIGHT" + ], + "type": "string" + }, + "startLine": { + "description": "For multi-line comments, the first line of the range that the comment applies to", + "type": "number" + }, + "startSide": { + "description": "For multi-line comments, the starting side of the diff that the comment applies to. LEFT indicates the previous state, RIGHT indicates the new state", + "enum": [ + "LEFT", + "RIGHT" + ], + "type": "string" + }, + "subjectType": { + "description": "The level at which the comment is targeted", + "enum": [ + "FILE", + "LINE" + ], + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "pullNumber", + "path", + "body", + "subjectType" + ] + } + }, + { + "name": "add_issue_comment", + "description": "Add a comment to a specific issue in a GitHub repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "body": { + "description": "Comment content", + "type": "string" + }, + "issue_number": { + "description": "Issue number to comment on", + "type": "number" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "issue_number", + "body" + ] + } + }, + { + "name": "add_sub_issue", + "description": "Add a sub-issue to a parent issue in a GitHub repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "issue_number": { + "description": "The number of the parent issue", + "type": "number" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "replace_parent": { + "description": "When true, replaces the sub-issue's current parent issue", + "type": "boolean" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "sub_issue_id": { + "description": "The ID of the sub-issue to add. 
ID is not the same as issue number", + "type": "number" + } + }, + "required": [ + "owner", + "repo", + "issue_number", + "sub_issue_id" + ] + } + }, + { + "name": "assign_copilot_to_issue", + "description": "Assign Copilot to a specific issue in a GitHub repository.\n\nThis tool can help with the following outcomes:\n- a Pull Request created with source code changes to resolve the issue\n\n\nMore information can be found at:\n- https://docs.github.com/en/copilot/using-github-copilot/using-copilot-coding-agent-to-work-on-tasks/about-assigning-tasks-to-copilot\n\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "issueNumber": { + "description": "Issue number", + "type": "number" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "issueNumber" + ] + } + }, + { + "name": "create_branch", + "description": "Create a new branch in a GitHub repository\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "branch": { + "description": "Name for new branch", + "type": "string" + }, + "from_branch": { + "description": "Source branch (defaults to repo default)", + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "branch" + ] + } + }, + { + "name": "create_issue", + "description": "Create a new issue in a GitHub repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "assignees": { + "description": "Usernames to assign to this issue", + "items": { + "type": "string" + }, + "type": "array" + }, + "body": { + "description": "Issue body content", + "type": "string" + }, + "labels": { + "description": "Labels to apply to this issue", + "items": { + "type": "string" + }, + "type": "array" + }, + "milestone": { + "description": "Milestone number", + "type": "number" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "title": { + "description": "Issue title", + "type": "string" + }, + "type": { + "description": "Type of this issue", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "title" + ] + } + }, + { + "name": "create_or_update_file", + "description": "Create or update a single file in a GitHub repository. If updating, you must provide the SHA of the file you want to update. 
Use this tool to create or update a file in a GitHub repository remotely; do not use it for local file operations.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "branch": { + "description": "Branch to create/update the file in", + "type": "string" + }, + "content": { + "description": "Content of the file", + "type": "string" + }, + "message": { + "description": "Commit message", + "type": "string" + }, + "owner": { + "description": "Repository owner (username or organization)", + "type": "string" + }, + "path": { + "description": "Path where to create/update the file", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "sha": { + "description": "Required if updating an existing file. The blob SHA of the file being replaced.", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "path", + "content", + "message", + "branch" + ] + } + }, + { + "name": "create_pull_request", + "description": "Create a new pull request in a GitHub repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "base": { + "description": "Branch to merge into", + "type": "string" + }, + "body": { + "description": "PR description", + "type": "string" + }, + "draft": { + "description": "Create as draft PR", + "type": "boolean" + }, + "head": { + "description": "Branch containing changes", + "type": "string" + }, + "maintainer_can_modify": { + "description": "Allow maintainer edits", + "type": "boolean" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "title": { + "description": "PR title", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "title", + "head", + "base" + ] + } + }, + { + "name": "create_repository", + "description": "Create a new GitHub repository in your account or specified organization\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "autoInit": { + "description": "Initialize with README", + "type": "boolean" + }, + "description": { + "description": "Repository description", + "type": "string" + }, + "name": { + "description": "Repository name", + "type": "string" + }, + "organization": { + "description": "Organization to create the repository in (omit to create in your personal account)", + "type": "string" + }, + "private": { + "description": "Whether repo should be private", + "type": "boolean" + } + }, + "required": [ + "name" + ] + } + }, + { + "name": "delete_file", + "description": "Delete a file from a GitHub repository\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "branch": { + "description": "Branch to delete the file from", + "type": "string" + }, + "message": { + "description": "Commit message", + "type": "string" + }, + "owner": { + "description": "Repository owner (username or organization)", + "type": "string" + }, + "path": { + "description": "Path to the file to delete", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "path", + "message", + "branch" + ] + } + }, + { + "name": "fork_repository", + "description": "Fork a GitHub repository to your account or specified 
organization\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "organization": { + "description": "Organization to fork to", + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ] + } + }, + { + "name": "get_commit", + "description": "Get details for a commit from a GitHub repository\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "include_diff": { + "default": true, + "description": "Whether to include file diffs and stats in the response. Default is true.", + "type": "boolean" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "sha": { + "description": "Commit SHA, branch name, or tag name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "sha" + ] + } + }, + { + "name": "get_file_contents", + "description": "Get the contents of a file or directory from a GitHub repository\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "owner": { + "description": "Repository owner (username or organization)", + "type": "string" + }, + "path": { + "default": "/", + "description": "Path to file/directory (directories must end with a slash '/')", + "type": "string" + }, + "ref": { + "description": "Accepts optional git refs such as `refs/tags/{tag}`, `refs/heads/{branch}` or `refs/pull/{pr_number}/head`", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "sha": { + "description": "Accepts optional commit SHA. 
If specified, it will be used instead of ref", + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ] + } + }, + { + "name": "get_issue", + "description": "Get details of a specific issue in a GitHub repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "issue_number": { + "description": "The number of the issue", + "type": "number" + }, + "owner": { + "description": "The owner of the repository", + "type": "string" + }, + "repo": { + "description": "The name of the repository", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "issue_number" + ] + } + }, + { + "name": "get_issue_comments", + "description": "Get comments for a specific issue in a GitHub repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "issue_number": { + "description": "Issue number", + "type": "number" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "issue_number" + ] + } + }, + { + "name": "get_label", + "description": "Get a specific label from a repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "name": { + "description": "Label name.", + "type": "string" + }, + "owner": { + "description": "Repository owner (username or organization name)", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "name" + ] + } + }, + { + "name": "get_latest_release", + "description": "Get the latest release in a GitHub repository\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ] + } + }, + { + "name": "get_me", + "description": "Get details of the authenticated GitHub user. Use this when a request is about the user's own profile for GitHub. 
Or when information is missing to build other tool calls.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": {} + } + }, + { + "name": "get_release_by_tag", + "description": "Get a specific release by its tag name in a GitHub repository\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "tag": { + "description": "Tag name (e.g., 'v1.0.0')", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "tag" + ] + } + }, + { + "name": "get_tag", + "description": "Get details about a specific git tag in a GitHub repository\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "tag": { + "description": "Tag name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "tag" + ] + } + }, + { + "name": "get_team_members", + "description": "Get member usernames of a specific team in an organization. Limited to organizations accessible with current credentials\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "org": { + "description": "Organization login (owner) that contains the team.", + "type": "string" + }, + "team_slug": { + "description": "Team slug", + "type": "string" + } + }, + "required": [ + "org", + "team_slug" + ] + } + }, + { + "name": "get_teams", + "description": "Get details of the teams the user is a member of. Limited to organizations accessible with current credentials\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "user": { + "description": "Username to get teams for. If not provided, uses the authenticated user.", + "type": "string" + } + } + } + }, + { + "name": "list_branches", + "description": "List branches in a GitHub repository\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "owner": { + "description": "Repository owner", + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ] + } + }, + { + "name": "list_commits", + "description": "Get list of commits of a branch in a GitHub repository. 
Returns at least 30 results per page by default, but can return more if specified using the perPage parameter (up to 100).\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "author": { + "description": "Author username or email address to filter commits by", + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "sha": { + "description": "Commit SHA, branch or tag name to list commits of. If not provided, uses the default branch of the repository. If a commit SHA is provided, will list commits up to that SHA.", + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ] + } + }, + { + "name": "list_issue_types", + "description": "List supported issue types for repository owner (organization).\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "owner": { + "description": "The organization owner of the repository", + "type": "string" + } + }, + "required": [ + "owner" + ] + } + }, + { + "name": "list_issues", + "description": "List issues in a GitHub repository. For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "after": { + "description": "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + "type": "string" + }, + "direction": { + "description": "Order direction. If provided, the 'orderBy' also needs to be provided.", + "enum": [ + "ASC", + "DESC" + ], + "type": "string" + }, + "labels": { + "description": "Filter by labels", + "items": { + "type": "string" + }, + "type": "array" + }, + "orderBy": { + "description": "Order issues by field. 
If provided, the 'direction' also needs to be provided.", + "enum": [ + "CREATED_AT", + "UPDATED_AT", + "COMMENTS" + ], + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "since": { + "description": "Filter by date (ISO 8601 timestamp)", + "type": "string" + }, + "state": { + "description": "Filter by state, by default both open and closed issues are returned when not provided", + "enum": [ + "OPEN", + "CLOSED" + ], + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ] + } + }, + { + "name": "list_label", + "description": "List labels from a repository or an issue\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "issue_number": { + "description": "Issue number - if provided, lists labels on the specific issue", + "type": "number" + }, + "owner": { + "description": "Repository owner (username or organization name) - required for all operations", + "type": "string" + }, + "repo": { + "description": "Repository name - required for all operations", + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ] + } + }, + { + "name": "list_pull_requests", + "description": "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "base": { + "description": "Filter by base branch", + "type": "string" + }, + "direction": { + "description": "Sort direction", + "enum": [ + "asc", + "desc" + ], + "type": "string" + }, + "head": { + "description": "Filter by head user/org and branch", + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "sort": { + "description": "Sort by", + "enum": [ + "created", + "updated", + "popularity", + "long-running" + ], + "type": "string" + }, + "state": { + "description": "Filter by state", + "enum": [ + "open", + "closed", + "all" + ], + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ] + } + }, + { + "name": "list_releases", + "description": "List releases in a GitHub repository\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "owner": { + "description": "Repository owner", + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ] + } + }, + { + "name": "list_sub_issues", + "description": "List sub-issues for a specific issue in a GitHub repository.\nProvided via MCP server 'github' (official: 
'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "issue_number": { + "description": "Issue number", + "type": "number" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "page": { + "description": "Page number for pagination (default: 1)", + "type": "number" + }, + "per_page": { + "description": "Number of results per page (max 100, default: 30)", + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "issue_number" + ] + } + }, + { + "name": "list_tags", + "description": "List git tags in a GitHub repository\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "owner": { + "description": "Repository owner", + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ] + } + }, + { + "name": "merge_pull_request", + "description": "Merge a pull request in a GitHub repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "commit_message": { + "description": "Extra detail for merge commit", + "type": "string" + }, + "commit_title": { + "description": "Title for merge commit", + "type": "string" + }, + "merge_method": { + "description": "Merge method", + "enum": [ + "merge", + "squash", + "rebase" + ], + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "pullNumber": { + "description": "Pull request number", + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "pullNumber" + ] + } + }, + { + "name": "pull_request_read", + "description": "Get information on a specific pull request in GitHub repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "method": { + "description": "Action to specify what pull request data needs to be retrieved from GitHub. \nPossible options: \n 1. get - Get details of a specific pull request.\n 2. get_diff - Get the diff of a pull request.\n 3. get_status - Get status of a head commit in a pull request. This reflects status of builds and checks.\n 4. get_files - Get the list of files changed in a pull request. Use with pagination parameters to control the number of results returned.\n 5. get_review_comments - Get the review comments on a pull request. Use with pagination parameters to control the number of results returned.\n 6. get_reviews - Get the reviews on a pull request. 
When asked for review comments, use get_review_comments method.\n", + "enum": [ + "get", + "get_diff", + "get_status", + "get_files", + "get_review_comments", + "get_reviews" + ], + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "pullNumber": { + "description": "Pull request number", + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "method", + "owner", + "repo", + "pullNumber" + ] + } + }, + { + "name": "pull_request_review_write", + "description": "Create and/or submit, delete review of a pull request.\n\nAvailable methods:\n- create: Create a new review of a pull request. If \"event\" parameter is provided, the review is submitted. If \"event\" is omitted, a pending review is created.\n- submit_pending: Submit an existing pending review of a pull request. This requires that a pending review exists for the current user on the specified pull request. The \"body\" and \"event\" parameters are used when submitting the review.\n- delete_pending: Delete an existing pending review of a pull request. This requires that a pending review exists for the current user on the specified pull request.\n\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "body": { + "description": "Review comment text", + "type": "string" + }, + "commitID": { + "description": "SHA of commit to review", + "type": "string" + }, + "event": { + "description": "Review action to perform.", + "enum": [ + "APPROVE", + "REQUEST_CHANGES", + "COMMENT" + ], + "type": "string" + }, + "method": { + "description": "The write operation to perform on pull request review.", + "enum": [ + "create", + "submit_pending", + "delete_pending" + ], + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "pullNumber": { + "description": "Pull request number", + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "method", + "owner", + "repo", + "pullNumber" + ] + } + }, + { + "name": "push_files", + "description": "Push multiple files to a GitHub repository in a single commit\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "branch": { + "description": "Branch to push to", + "type": "string" + }, + "files": { + "description": "Array of file objects to push, each object with path (string) and content (string)", + "items": { + "additionalProperties": false, + "properties": { + "content": { + "description": "file content", + "type": "string" + }, + "path": { + "description": "path to the file", + "type": "string" + } + }, + "required": [ + "path", + "content" + ], + "type": "object" + }, + "type": "array" + }, + "message": { + "description": "Commit message", + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "branch", + "files", + "message" + ] + } + }, + { + "name": "remove_sub_issue", + "description": "Remove a sub-issue from a parent issue in a GitHub 
repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "issue_number": { + "description": "The number of the parent issue", + "type": "number" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "sub_issue_id": { + "description": "The ID of the sub-issue to remove. ID is not the same as issue number", + "type": "number" + } + }, + "required": [ + "owner", + "repo", + "issue_number", + "sub_issue_id" + ] + } + }, + { + "name": "reprioritize_sub_issue", + "description": "Reprioritize a sub-issue to a different position in the parent issue's sub-issue list.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "after_id": { + "description": "The ID of the sub-issue to be prioritized after (either after_id OR before_id should be specified)", + "type": "number" + }, + "before_id": { + "description": "The ID of the sub-issue to be prioritized before (either after_id OR before_id should be specified)", + "type": "number" + }, + "issue_number": { + "description": "The number of the parent issue", + "type": "number" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "sub_issue_id": { + "description": "The ID of the sub-issue to reprioritize. ID is not the same as issue number", + "type": "number" + } + }, + "required": [ + "owner", + "repo", + "issue_number", + "sub_issue_id" + ] + } + }, + { + "name": "request_copilot_review", + "description": "Request a GitHub Copilot code review for a pull request. Use this for automated feedback on pull requests, usually before requesting a human reviewer.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "owner": { + "description": "Repository owner", + "type": "string" + }, + "pullNumber": { + "description": "Pull request number", + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "pullNumber" + ] + } + }, + { + "name": "search_code", + "description": "Fast and precise code search across ALL GitHub repositories using GitHub's native search engine. Best for finding exact symbols, functions, classes, or specific code patterns.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "order": { + "description": "Sort order for results", + "enum": [ + "asc", + "desc" + ], + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "query": { + "description": "Search query using GitHub's powerful code search syntax. Examples: 'content:Skill language:Java org:github', 'NOT is:archived language:Python OR language:go', 'repo:github/github-mcp-server'. 
Supports exact matching, language filters, path filters, and more.", + "type": "string" + }, + "sort": { + "description": "Sort field ('indexed' only)", + "type": "string" + } + }, + "required": [ + "query" + ] + } + }, + { + "name": "search_issues", + "description": "Search for issues in GitHub repositories using issues search syntax already scoped to is:issue\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "order": { + "description": "Sort order", + "enum": [ + "asc", + "desc" + ], + "type": "string" + }, + "owner": { + "description": "Optional repository owner. If provided with repo, only issues for this repository are listed.", + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "query": { + "description": "Search query using GitHub issues search syntax", + "type": "string" + }, + "repo": { + "description": "Optional repository name. If provided with owner, only issues for this repository are listed.", + "type": "string" + }, + "sort": { + "description": "Sort field by number of matches of categories, defaults to best match", + "enum": [ + "comments", + "reactions", + "reactions-+1", + "reactions--1", + "reactions-smile", + "reactions-thinking_face", + "reactions-heart", + "reactions-tada", + "interactions", + "created", + "updated" + ], + "type": "string" + } + }, + "required": [ + "query" + ] + } + }, + { + "name": "search_pull_requests", + "description": "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "order": { + "description": "Sort order", + "enum": [ + "asc", + "desc" + ], + "type": "string" + }, + "owner": { + "description": "Optional repository owner. If provided with repo, only pull requests for this repository are listed.", + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "query": { + "description": "Search query using GitHub pull request search syntax", + "type": "string" + }, + "repo": { + "description": "Optional repository name. If provided with owner, only pull requests for this repository are listed.", + "type": "string" + }, + "sort": { + "description": "Sort field by number of matches of categories, defaults to best match", + "enum": [ + "comments", + "reactions", + "reactions-+1", + "reactions--1", + "reactions-smile", + "reactions-thinking_face", + "reactions-heart", + "reactions-tada", + "interactions", + "created", + "updated" + ], + "type": "string" + } + }, + "required": [ + "query" + ] + } + }, + { + "name": "search_repositories", + "description": "Find GitHub repositories by name, description, readme, topics, or other metadata. 
Perfect for discovering projects, finding examples, or locating specific repositories across GitHub.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "minimal_output": { + "default": true, + "description": "Return minimal repository information (default: true). When false, returns full GitHub API repository objects.", + "type": "boolean" + }, + "order": { + "description": "Sort order", + "enum": [ + "asc", + "desc" + ], + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "query": { + "description": "Repository search query. Examples: 'machine learning in:name stars:>1000 language:python', 'topic:react', 'user:facebook'. Supports advanced search syntax for precise filtering.", + "type": "string" + }, + "sort": { + "description": "Sort repositories by field, defaults to best match", + "enum": [ + "stars", + "forks", + "help-wanted-issues", + "updated" + ], + "type": "string" + } + }, + "required": [ + "query" + ] + } + }, + { + "name": "search_users", + "description": "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "order": { + "description": "Sort order", + "enum": [ + "asc", + "desc" + ], + "type": "string" + }, + "page": { + "description": "Page number for pagination (min 1)", + "minimum": 1, + "type": "number" + }, + "perPage": { + "description": "Results per page for pagination (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "number" + }, + "query": { + "description": "User search query. Examples: 'john smith', 'location:seattle', 'followers:>100'. Search is automatically scoped to type:user.", + "type": "string" + }, + "sort": { + "description": "Sort users by number of followers or repositories, or when the person joined GitHub.", + "enum": [ + "followers", + "repositories", + "joined" + ], + "type": "string" + } + }, + "required": [ + "query" + ] + } + }, + { + "name": "update_issue", + "description": "Update an existing issue in a GitHub repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "assignees": { + "description": "New assignees", + "items": { + "type": "string" + }, + "type": "array" + }, + "body": { + "description": "New description", + "type": "string" + }, + "duplicate_of": { + "description": "Issue number that this issue is a duplicate of. Only used when state_reason is 'duplicate'.", + "type": "number" + }, + "issue_number": { + "description": "Issue number to update", + "type": "number" + }, + "labels": { + "description": "New labels", + "items": { + "type": "string" + }, + "type": "array" + }, + "milestone": { + "description": "New milestone number", + "type": "number" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "state": { + "description": "New state", + "enum": [ + "open", + "closed" + ], + "type": "string" + }, + "state_reason": { + "description": "Reason for the state change. 
Ignored unless state is changed.", + "enum": [ + "completed", + "not_planned", + "duplicate" + ], + "type": "string" + }, + "title": { + "description": "New title", + "type": "string" + }, + "type": { + "description": "New issue type", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "issue_number" + ] + } + }, + { + "name": "update_pull_request", + "description": "Update an existing pull request in a GitHub repository.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "base": { + "description": "New base branch name", + "type": "string" + }, + "body": { + "description": "New description", + "type": "string" + }, + "draft": { + "description": "Mark pull request as draft (true) or ready for review (false)", + "type": "boolean" + }, + "maintainer_can_modify": { + "description": "Allow maintainer edits", + "type": "boolean" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "pullNumber": { + "description": "Pull request number to update", + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "reviewers": { + "description": "GitHub usernames to request reviews from", + "items": { + "type": "string" + }, + "type": "array" + }, + "state": { + "description": "New state", + "enum": [ + "open", + "closed" + ], + "type": "string" + }, + "title": { + "description": "New title", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "pullNumber" + ] + } + }, + { + "name": "update_pull_request_branch", + "description": "Update the branch of a pull request with the latest changes from the base branch.\nProvided via MCP server 'github' (official: 'github-mcp-server')", + "arguments_schema": { + "type": "object", + "properties": { + "expectedHeadSha": { + "description": "The expected SHA of the pull request's HEAD ref", + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "pullNumber": { + "description": "Pull request number", + "type": "number" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "pullNumber" + ] + } + } +] \ No newline at end of file diff --git a/src/benchmark/utils/__init__.py b/src/benchmark/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/benchmark/utils/json_log_utils.py b/src/benchmark/utils/json_log_utils.py new file mode 100644 index 0000000..8508488 --- /dev/null +++ b/src/benchmark/utils/json_log_utils.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json +from dataclasses import asdict, is_dataclass +from decimal import Decimal +from enum import Enum +from json import JSONDecodeError +from pathlib import Path +from typing import Any, Iterable + + +class JsonLogUtils: + """Helpers for reading/writing benchmark logs as JSON. + + Key requirements: + - logs are written pretty-printed (`indent=4`) -> multi-line JSON + - tolerate legacy files that may contain multiple JSON objects concatenated together + - handle common non-JSON types (e.g. `Enum`) via a `default=` converter + """ + + _INDENT: int = 4 + + @staticmethod + def iter_json_objects(text: str) -> Iterable[dict[str, Any]]: + """Iterate JSON objects from a string. + + Supports: + - a single pretty-printed JSON object (multi-line, e.g. `indent=4`) + - multiple JSON objects concatenated together + + This intentionally does *not* assume "1 object = 1 line". 
+ """ + decoder = json.JSONDecoder() + idx = 0 + length = len(text) + + while True: + while idx < length and text[idx].isspace(): + idx += 1 + if idx >= length: + return + + value, end = decoder.raw_decode(text, idx) + idx = end + + if isinstance(value, dict): + yield value + continue + + if isinstance(value, list): + for item in value: + if not isinstance(item, dict): + raise ValueError(f"Expected dict items in JSON list, got: {type(item)!r}") + yield item + continue + + raise ValueError(f"Expected JSON object (dict) or list of objects, got: {type(value)!r}") + + @classmethod + def load_log_payloads(cls, path: Path) -> list[dict[str, Any]]: + """Load one or more JSON objects from a log file.""" + content = path.read_text(encoding="utf-8") + if not content.strip(): + return [] + + try: + return list(cls.iter_json_objects(content)) + except JSONDecodeError as e: + raise ValueError(f"Failed to parse JSON log file: {path}") from e + + @staticmethod + def json_default(obj: Any) -> Any: + """Fallback conversion for non-JSON-serializable objects.""" + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, Decimal): + # Preserve precision in logs; values are read back as JSON numbers. + return str(obj) + if isinstance(obj, Path): + return str(obj) + if is_dataclass(obj): + return asdict(obj) + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + if hasattr(obj, "to_dict"): + return obj.to_dict() + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") + + @classmethod + def dumps(cls, data: Any) -> str: + return json.dumps(data, ensure_ascii=False, indent=cls._INDENT, default=cls.json_default) + + @classmethod + def write(cls, path: Path, data: Any) -> None: + path.write_text(cls.dumps(data) + "\n", encoding="utf-8") diff --git a/src/benchmarking/base_logger.py b/src/benchmarking/base_logger.py deleted file mode 100644 index 6e7310f..0000000 --- a/src/benchmarking/base_logger.py +++ /dev/null @@ -1,48 +0,0 @@ -import logging -import os - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any - -from src.summarize_algorithms.core.models import ( - DialogueState, - MemoryBankDialogueState, - MetricState, - RecsumDialogueState, - Session, -) - - -class BaseLogger(ABC): - def __init__(self, logs_dir: str | Path = "logs/memory") -> None: - os.makedirs(logs_dir, exist_ok=True) - self.log_dir = Path(logs_dir) - self.logger = logging.getLogger(__name__) - - @abstractmethod - def log_iteration( - self, - system_name: str, - query: str, - iteration: int, - sessions: list[Session], - state: DialogueState | None, - metric: MetricState | None = None - ) -> dict[str, Any]: - ... 
- - @staticmethod - def _serialize_memories(state: DialogueState) -> dict[str, Any]: - result: dict[str, Any] = {} - - if state.code_memory_storage is not None: - result["code_memory_storage"] = state.code_memory_storage.to_dict() - if state.tool_memory_storage is not None: - result["tool_memory_storage"] = state.tool_memory_storage.to_dict() - if isinstance(state, MemoryBankDialogueState): - result["text_memory_storage"] = state.text_memory_storage.to_dict() - if isinstance(state, RecsumDialogueState): - result["text_memory"] = state.text_memory - - return result diff --git a/src/benchmarking/baseline.py b/src/benchmarking/baseline.py deleted file mode 100644 index 93f13bd..0000000 --- a/src/benchmarking/baseline.py +++ /dev/null @@ -1,73 +0,0 @@ -import os - -from typing import Any, List, Optional - -from dotenv import load_dotenv -from langchain_community.callbacks import get_openai_callback -from langchain_core.language_models import BaseChatModel -from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables import Runnable -from langchain_openai import ChatOpenAI -from pydantic import SecretStr - -from src.benchmarking.baseline_logger import BaselineLogger -from src.benchmarking.prompts import BASELINE_PROMPT -from src.summarize_algorithms.core.dialog import Dialog -from src.summarize_algorithms.core.models import DialogueState, OpenAIModels, Session - - -class DialogueBaseline(Dialog): - def __init__(self, system_name: str, llm: Optional[BaseChatModel] = None) -> None: - load_dotenv() - - self.system_name = system_name - - api_key: str | None = os.getenv("OPENAI_API_KEY") - if api_key is not None: - self.llm = llm or ChatOpenAI( - model=OpenAIModels.GPT_5_MINI.value, - api_key=SecretStr(api_key) - ) - else: - raise ValueError("OPENAI_API_KEY environment variable is not loaded") - - self.prompt_template = BASELINE_PROMPT - self.chain = self._build_chain() - self.prompt_tokens = 0 - self.completion_tokens = 0 - self.total_cost = 0.0 - - self.baseline_logger = BaselineLogger() - - def _build_chain(self) -> Runnable[dict[str, Any], str]: - return self.prompt_template | self.llm | StrOutputParser() - - #TODO remove iteration argument and logging - def process_dialogue(self, sessions: List[Session], query: str, iteration: int | None = None) -> DialogueState: - context_messages = [] - for session in sessions: - for message in session.messages: - context_messages.append(f"{message.role}: {message.content}") - context = "\n".join(context_messages) - with get_openai_callback() as cb: - result = self.chain.invoke({"context": context, "query": query}) - - self.prompt_tokens += cb.prompt_tokens - self.completion_tokens += cb.completion_tokens - self.total_cost += cb.total_cost - - if iteration is not None: - self.baseline_logger.log_iteration( - system_name=self.system_name, - query=query, - iteration=iteration, - sessions=sessions - ) - - return DialogueState( - dialogue_sessions=sessions, - query=query, - _response=result, - code_memory_storage=None, - tool_memory_storage=None - ) diff --git a/src/benchmarking/baseline_logger.py b/src/benchmarking/baseline_logger.py deleted file mode 100644 index 9ab25f0..0000000 --- a/src/benchmarking/baseline_logger.py +++ /dev/null @@ -1,36 +0,0 @@ -import json - -from datetime import datetime -from typing import Any - -from src.benchmarking.base_logger import BaseLogger -from src.summarize_algorithms.core.models import DialogueState, MetricState, Session - - -class BaselineLogger(BaseLogger): - def log_iteration( - self, - 
system_name: str, - query: str, - iteration: int, - sessions: list[Session], - state: DialogueState | None = None, - metric: MetricState | None = None - ) -> dict[str, Any]: - self.logger.info(f"Logging iteration {iteration} to {self.log_dir}") - - record = { - "timestamp": datetime.now().isoformat(), - "iteration": iteration, - "system": system_name, - "query": query, - "sessions": [s.to_dict() for s in sessions], - } - - with open(self.log_dir / (system_name + str(iteration) + ".jsonl"), "a", encoding="utf-8") as f: - f.write(json.dumps(record, ensure_ascii=False, indent=4)) - f.write("\n") - - self.logger.info(f"Saved successfully iteration {iteration} to {self.log_dir}") - - return record diff --git a/src/benchmarking/memory_logger.py b/src/benchmarking/memory_logger.py deleted file mode 100644 index 6df7016..0000000 --- a/src/benchmarking/memory_logger.py +++ /dev/null @@ -1,45 +0,0 @@ -import json - -from datetime import datetime -from typing import Any - -from src.benchmarking.base_logger import BaseLogger -from src.summarize_algorithms.core.models import DialogueState, MetricState, Session - - -class MemoryLogger(BaseLogger): - def log_iteration( - self, - system_name: str, - query: str, - iteration: int, - sessions: list[Session], - state: DialogueState | None, - metric: MetricState | None = None - ) -> dict[str, Any]: - self.logger.info(f"Logging iteration {iteration} to {self.log_dir}") - - if state is None: - raise ValueError("'state' argument is necessary for MemoryLogger") - - record = { - "timestamp": datetime.now().isoformat(), - "iteration": iteration, - "system": system_name, - "query": query, - "response": getattr(state, "response", None), - "memory": MemoryLogger._serialize_memories(state), - "sessions": [s.to_dict() for s in sessions], - } - - if metric is not None: - record["metric_name"] = metric.metric.value - record["metric_value"] = metric.value - - with open(self.log_dir / (system_name + str(iteration) + ".jsonl"), "a", encoding="utf-8") as f: - f.write(json.dumps(record, ensure_ascii=False, indent=4)) - f.write("\n") - - self.logger.info(f"Saved successfully iteration {iteration} to {self.log_dir}") - - return record diff --git a/src/benchmarking/tool_metrics/calculator.py b/src/benchmarking/tool_metrics/calculator.py deleted file mode 100644 index bfa497c..0000000 --- a/src/benchmarking/tool_metrics/calculator.py +++ /dev/null @@ -1,60 +0,0 @@ -from pathlib import Path -from typing import Any - -from src.benchmarking.base_logger import BaseLogger -from src.benchmarking.tool_metrics.base_evaluator import BaseEvaluator -from src.summarize_algorithms.core.dialog import Dialog -from src.summarize_algorithms.core.models import MetricState, Session - - -class Calculator: - """ - Class for run, evaluate (by compare with reference) and save results of llm's memory algorithms. - """ - - def __init__(self, logger: BaseLogger, path_to_save: str | Path | None = None) -> None: - """ - Initialization of Calculator class. - :param logger: the class for saving results of running and evaluating algorithms. - :param path_to_save: - """ - self.path_to_save: str | Path | None = path_to_save - self.logger = logger - - def evaluate( - self, - algorithms: list[Dialog], - evaluator_function: BaseEvaluator, - sessions: list[Session], - reference_session: Session - ) -> list[dict[str, Any]]: - """ - The main method for run and evaluate algorithm with the ast sessions. - :param algorithms: class that implements Dialog protocol. 
- :param evaluator_function: class inheritance of BaseEvaluator - for evaluating algorithm's results. - :param sessions: past user-tool-model interactions. - :param reference_session: reference session for comparing with algorithm's results. - :return: list[dict[str, Any]]: results of running and evaluating algorithm. - """ - user_role_messages = reference_session.get_messages_by_role("USER") - query = user_role_messages[-1] - - metrics: list[dict[str, Any]] = [] - for algorithm in algorithms: - state = algorithm.process_dialogue(sessions, query.content) - metric: MetricState = evaluator_function.evaluate(sessions, query, state) - - record: dict[str, Any] | None = self.logger.log_iteration( - algorithm.__class__.__name__, - query.content, - 1, - sessions, - state, - metric - ) - - assert record is not None, "logger didn't return a value" - - metrics.append(record) - - return metrics diff --git a/src/main.py b/src/main.py index 17517f4..1439843 100644 --- a/src/main.py +++ b/src/main.py @@ -1,10 +1,10 @@ -from src.summarize_algorithms.core.models import ( +from src.algorithms.summarize_algorithms.core.models import ( BaseBlock, CodeBlock, Session, ToolCallBlock, ) -from src.summarize_algorithms.memory_bank.dialogue_system import ( +from src.algorithms.summarize_algorithms.memory_bank.dialogue_system import ( MemoryBankDialogueSystem, ) @@ -75,7 +75,7 @@ def sort_list(lst): arguments='{"code": "def sort_list..."}', response="The code is correct, but the sorting is inefficient for large lists.", content="Code check completed. The code is correct, but the sorting is inefficient for" - " large lists.", + " large lists.", ), ] ), diff --git a/src/prompt_templates/__init__.py b/src/prompt_templates/__init__.py new file mode 100644 index 0000000..a6bb159 --- /dev/null +++ b/src/prompt_templates/__init__.py @@ -0,0 +1 @@ +"""Shared Jinja2 prompt templates used to build a single unified SystemMessage.""" diff --git a/src/prompt_templates/bridge_to_conversation.j2 b/src/prompt_templates/bridge_to_conversation.j2 new file mode 100644 index 0000000..1e314cc --- /dev/null +++ b/src/prompt_templates/bridge_to_conversation.j2 @@ -0,0 +1,8 @@ +### EXECUTION INSTRUCTION +The System Instruction ends here. +Immediately following this message, you will be provided with the **Conversation History**. +Your task is to: +1. Analyze the Conversation History to understand the context and state. +2. Identify the **LAST** message in the history as the `LatestUserRequest`. +3. Create the JSON Action Plan to address that specific request, utilizing the Memory and Tools defined above. +4. Do not do real tool calls, just create the plan. diff --git a/src/prompt_templates/introduction.j2 b/src/prompt_templates/introduction.j2 new file mode 100644 index 0000000..0a13d1b --- /dev/null +++ b/src/prompt_templates/introduction.j2 @@ -0,0 +1,18 @@ +**Role & Objective** +You are an expert autonomous agent planner. You are given a multi-session conversation between a user and an LLM, including prior tool calls and results. +Your goal is to plan how to answer the **latest user request** (which is the final message in the conversation history provided after this instruction) without executing any tools immediately. +Instead, produce a strict JSON action plan that lists, in the correct execution order, every tool you would call (with arguments) and any non-tool actions (reasoning, synthesis) required. 
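For concreteness, a plan in the shape this prompt asks for might look like the sketch below. It is purely illustrative: the step fields mirror the planning guidelines that follow and the test fixtures elsewhere in this patch (`plan_steps`, `kind`, `name`, `args`, `id`, `description`, `depends_on`, plus `assumptions` and `memory_used`); the binding contract is `PLAN_SCHEMA` in `tools_and_schemas/parsed_jsons`, which this example does not reproduce, and `read_file` simply stands in for whichever catalog tool applies.

```python
example_plan = {
    "plan_steps": [
        {
            "id": "s1",
            "kind": "tool_call",
            "name": "read_file",
            "args": {"path": None},  # unknown at planning time -> null, noted in assumptions
            "description": "Read the file the user referred to.",
            "depends_on": [],
        },
        {
            "id": "s2",
            "kind": "final_answer",
            "name": None,  # no tool for the composition step (illustrative choice)
            "args": {},
            "description": "Compose the answer from the contents retrieved in s1.",
            "depends_on": ["s1"],
        },
    ],
    "assumptions": ["The exact file path must be recovered from the conversation history."],
    "memory_used": [],
}
```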
+ +**Inputs Overview** +- **Memory & Context:** Distilled notes or summaries from previous interactions (provided below, if available). +- **Tools Catalog:** A list of executable functions (provided below). +- **Conversation History:** The actual chat logs (provided as separate messages following this system instruction). + +**Planning Guidelines** +1. **Dependency Management:** Order steps exactly as they should be executed. Use `depends_on` to express explicit dependencies. +2. **Tool Calls:** Use `kind = "tool_call"` for planned invocations. Populate `args` precisely according to the schema. +3. **Missing Data:** If an argument value is unknown at planning time, set it to `null` and add an entry in `assumptions` describing what data needs to be discovered. +4. **Final Answer:** Use `kind = "final_answer"` for the final composition step. Do not write the actual answer text, just the plan to compose it. +5. **Memory Usage:** If you rely on information provided in the Memory sections, strictly reference it in the `memory_used` field of the JSON output. + +--- diff --git a/src/prompt_templates/memory_injection.j2 b/src/prompt_templates/memory_injection.j2 new file mode 100644 index 0000000..1612ed2 --- /dev/null +++ b/src/prompt_templates/memory_injection.j2 @@ -0,0 +1,19 @@ +{% if recap is not none and recap|trim != "" %} +### RECAP: +{{ recap }} +{% endif %} + +{% if memory_bank is not none and memory_bank|trim != "" %} +### MEMORY BANK: +{{ memory_bank }} +{% endif %} + +{% if code_knowledge is not none and code_knowledge|trim != "" %} +### CODE KNOWLEDGE: +{{ code_knowledge }} +{% endif %} + +{% if tool_memory is not none and tool_memory|trim != "" %} +### TOOL MEMORY: +{{ tool_memory }} +{% endif %} diff --git a/src/prompt_templates/schema_and_tool.j2 b/src/prompt_templates/schema_and_tool.j2 new file mode 100644 index 0000000..d6817d0 --- /dev/null +++ b/src/prompt_templates/schema_and_tool.j2 @@ -0,0 +1,25 @@ +--- + +**Tools Catalog** +You may only plan calls to the following tools. Do not invent tools. + +{{ tools }} + +**Output JSON Schema** +Return only JSON that conforms to this schema. + +{{ schema_json }} + +**MemoryArtifacts** +{% if memory_mode == "baseline" %} +None (stateless) or session-only context. +{% else %} +May include injected sections such as RECAP / MEMORY BANK / CODE KNOWLEDGE / TOOL MEMORY when available. 
+{% endif %} + +{% if examples is not none and examples|trim != "" %} +**Examples** +{{ examples }} +{% endif %} + +--- diff --git a/src/summarize_algorithms/core/base_dialogue_system.py b/src/summarize_algorithms/core/base_dialogue_system.py deleted file mode 100644 index f2afbc1..0000000 --- a/src/summarize_algorithms/core/base_dialogue_system.py +++ /dev/null @@ -1,137 +0,0 @@ -import functools -import os - -from abc import ABC, abstractmethod -from typing import Any, Optional, Type - -from dotenv import load_dotenv -from langchain_community.callbacks import get_openai_callback -from langchain_core.embeddings import Embeddings -from langchain_core.language_models import BaseChatModel -from langchain_core.prompts import PromptTemplate -from langchain_openai import ChatOpenAI -from langgraph.constants import END -from langgraph.graph import StateGraph -from langgraph.graph.state import CompiledStateGraph -from pydantic import SecretStr - -from src.benchmarking.memory_logger import MemoryLogger -from src.summarize_algorithms.core.graph_nodes import ( - UpdateState, - generate_response_node, - should_continue_memory_update, - update_memory_node, -) -from src.summarize_algorithms.core.models import ( - DialogueState, - OpenAIModels, - Session, - WorkflowNode, -) -from src.summarize_algorithms.core.prompts import RESPONSE_GENERATION_PROMPT -from src.summarize_algorithms.core.response_generator import ResponseGenerator - - -class BaseDialogueSystem(ABC): - def __init__( - self, - llm: Optional[BaseChatModel] = None, - embed_code: bool = False, - embed_tool: bool = False, - embed_model: Optional[Embeddings] = None, - max_session_id: int = 3, - ) -> None: - load_dotenv() - - api_key: str | None = os.getenv("OPENAI_API_KEY") - if api_key is not None: - self.llm = llm or ChatOpenAI( - model=OpenAIModels.GPT_5_MINI.value, - api_key=SecretStr(api_key) - ) - else: - raise ValueError("OPENAI_API_KEY environment variable is not loaded") - - self.summarizer = self._build_summarizer() - self.response_generator = ResponseGenerator( - self.llm, self._get_response_prompt_template() - ) - self.graph = self._build_graph() - self.state: Optional[DialogueState] = None - self.embed_code = embed_code - self.embed_tool = embed_tool - self.embed_model = embed_model - self.max_session_id = max_session_id - self.prompt_tokens = 0 - self.completion_tokens = 0 - self.total_cost = 0.0 - - self.memory_logger = MemoryLogger() - self.iteration = 0 - - @abstractmethod - def _build_summarizer(self) -> Any: - pass - - @staticmethod - def _get_response_prompt_template() -> PromptTemplate: - return RESPONSE_GENERATION_PROMPT - - @abstractmethod - def _get_initial_state(self, sessions: list[Session], query: str) -> DialogueState: - pass - - @property - @abstractmethod - def _get_dialogue_state_class(self) -> Type[DialogueState]: - pass - - def _build_graph(self) -> CompiledStateGraph: - workflow = StateGraph(self._get_dialogue_state_class) - - workflow.add_node( - WorkflowNode.UPDATE_MEMORY.value, - functools.partial(update_memory_node, self.summarizer), - ) - workflow.add_node( - WorkflowNode.GENERATE_RESPONSE.value, - functools.partial(generate_response_node, self.response_generator), - ) - - workflow.set_entry_point(WorkflowNode.UPDATE_MEMORY.value) - - workflow.add_conditional_edges( - WorkflowNode.UPDATE_MEMORY.value, - should_continue_memory_update, - { - UpdateState.CONTINUE_UPDATE.value: WorkflowNode.UPDATE_MEMORY.value, - UpdateState.FINISH_UPDATE.value: WorkflowNode.GENERATE_RESPONSE.value, - }, - ) - - 
workflow.add_edge(WorkflowNode.GENERATE_RESPONSE.value, END) - - return workflow.compile() - - def process_dialogue(self, sessions: list[Session], query: str) -> DialogueState: - initial_state = self._get_initial_state(sessions, query) - with get_openai_callback() as cb: - self.state = self._get_dialogue_state_class( - **self.graph.invoke(initial_state) - ) - - self.prompt_tokens += cb.prompt_tokens - self.completion_tokens += cb.completion_tokens - self.total_cost += cb.total_cost - - self.iteration += 1 - system_name = self.__class__.__name__ - self.memory_logger.log_iteration( - system_name, - query, - self.iteration, - sessions, - self.state, - ) - - return self.state if self.state is not None else initial_state diff --git a/src/summarize_algorithms/core/base_summarizer.py b/src/summarize_algorithms/core/base_summarizer.py deleted file mode 100644 index 20b53d3..0000000 --- a/src/summarize_algorithms/core/base_summarizer.py +++ /dev/null @@ -1,21 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - -from langchain_core.language_models import BaseChatModel -from langchain_core.prompts import PromptTemplate -from langchain_core.runnables import Runnable - - -class BaseSummarizer(ABC): - def __init__(self, llm: BaseChatModel, prompt: PromptTemplate) -> None: - self.llm = llm - self.prompt = prompt - self.chain = self._build_chain() - - @abstractmethod - def _build_chain(self) -> Runnable[dict[str, Any], Any]: - pass - - @abstractmethod - def summarize(self, *args: Any, **kwargs: Any) -> Any: - pass diff --git a/src/summarize_algorithms/core/dialog.py b/src/summarize_algorithms/core/dialog.py deleted file mode 100644 index ce5fbe5..0000000 --- a/src/summarize_algorithms/core/dialog.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Protocol - -from src.summarize_algorithms.core.models import DialogueState, Session - - -class Dialog(Protocol): - def process_dialogue(self, sessions: list[Session], query: str) -> DialogueState: - ... 
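The removed `_build_graph` above is a small LangGraph state machine: loop on the memory-update node until the router says to stop, then generate the response once and end. A stripped-down, runnable sketch of that wiring is shown below; the state, node, and edge names are illustrative stand-ins rather than the project's own, and it mirrors the `StateGraph` calls used in the deleted code, whose exact surface may differ slightly across langgraph releases.

```python
from typing import TypedDict

from langgraph.graph import END, StateGraph


class PlannerState(TypedDict):
    sessions_left: int
    response: str


def update_memory(state: PlannerState) -> dict:
    # Consume one pending session per pass, as the real node does per dialogue session.
    return {"sessions_left": state["sessions_left"] - 1}


def generate_response(state: PlannerState) -> dict:
    return {"response": "final answer"}


def should_continue(state: PlannerState) -> str:
    return "continue_update" if state["sessions_left"] > 0 else "finish_update"


workflow = StateGraph(PlannerState)
workflow.add_node("update_memory", update_memory)
workflow.add_node("generate_response", generate_response)
workflow.set_entry_point("update_memory")
workflow.add_conditional_edges(
    "update_memory",
    should_continue,
    {"continue_update": "update_memory", "finish_update": "generate_response"},
)
workflow.add_edge("generate_response", END)

app = workflow.compile()
print(app.invoke({"sessions_left": 2, "response": ""}))
```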
diff --git a/src/summarize_algorithms/core/models.py b/src/summarize_algorithms/core/models.py deleted file mode 100644 index 902b6a6..0000000 --- a/src/summarize_algorithms/core/models.py +++ /dev/null @@ -1,168 +0,0 @@ -from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Iterator, Optional - -from dataclasses_json import dataclass_json - - -class OpenAIModels(Enum): - GPT_3_5_TURBO = "gpt-3.5-turbo" - GPT_4_1_MINI = "gpt-4.1-mini" - GPT_4_O = "gpt-4o" - GPT_4_1 = "gpt-4.1" - GPT_5_NANO = "gpt-5-nano" - GPT_5_MINI = "gpt-5-mini" - - -@dataclass -class BaseBlock: - role: str - content: str - - def __str__(self) -> str: - return f"{self.role}: {self.content}" - - -@dataclass -class CodeBlock(BaseBlock): - code: str - - -@dataclass -class ToolCallBlock(BaseBlock): - id: str - name: str - arguments: str - response: str - - -class Session: - def __init__(self, messages: list[BaseBlock]) -> None: - self.messages = messages - - def __len__(self) -> int: - return len(self.messages) - - def __str__(self) -> str: - result_messages = [] - for msg in self.messages: - if isinstance(msg, CodeBlock): - result_messages.append(f"{msg.role}: {msg.code}") - if isinstance(msg, ToolCallBlock): - result_messages.append( - f"Tool Call [{msg.id}]: {msg.name} - {msg.arguments} -> {msg.response}" - ) - else: - result_messages.append(f"{msg.role}: {msg.content}") - return "\n".join(result_messages) - - def __getitem__(self, index: int) -> BaseBlock: - return self.messages[index] - - def __iter__(self) -> Iterator[BaseBlock]: - return iter(self.messages) - - def to_dict(self) -> dict[str, Any]: - result_messages = [] - for msg in self.messages: - if isinstance(msg, CodeBlock): - result_messages.append({ - "type": "code", - "role": msg.role, - "code": msg.code, - }) - elif isinstance(msg, ToolCallBlock): - result_messages.append({ - "type": "tool_call", - "id": msg.id, - "name": msg.name, - "arguments": msg.arguments, - "response": msg.response, - }) - else: - result_messages.append({ - "type": "text", - "role": msg.role, - "content": msg.content, - }) - return {"messages": result_messages} - - def get_messages_by_role(self, role: str) -> list[BaseBlock]: - return [msg for msg in self.messages if msg.role == role] - - def get_text_blocks(self) -> list[BaseBlock]: - return [ - msg - for msg in self.messages - if not (isinstance(msg, CodeBlock) or isinstance(msg, ToolCallBlock)) - ] - - def get_code_blocks(self) -> list[CodeBlock]: - return [msg for msg in self.messages if isinstance(msg, CodeBlock)] - - def get_tool_calls(self) -> list[ToolCallBlock]: - return [msg for msg in self.messages if isinstance(msg, ToolCallBlock)] - - -@dataclass_json -@dataclass -class DialogueState: - from src.summarize_algorithms.core.memory_storage import MemoryStorage - - dialogue_sessions: list[Session] - code_memory_storage: Optional[MemoryStorage] - tool_memory_storage: Optional[MemoryStorage] - query: str - current_session_index: int = 0 - _response: Optional[str] = None - - @property - def response(self) -> str: - if self._response is None: - raise ValueError("Response has not been generated yet.") - return self._response - - @property - def current_context(self) -> Session: - return self.dialogue_sessions[-1] - - -@dataclass_json -@dataclass -class RecsumDialogueState(DialogueState): - text_memory: list[list[str]] = field(default_factory=list) - - @property - def latest_memory(self) -> str: - return "\n".join(self.text_memory[-1]) if self.text_memory else "" - - -@dataclass_json -@dataclass 
-class MemoryBankDialogueState(DialogueState): - from src.summarize_algorithms.core.memory_storage import MemoryStorage - - text_memory_storage: MemoryStorage = field(default_factory=MemoryStorage) - - -class MetricType(Enum): - COHERENCE = "coherence" - - -@dataclass_json -@dataclass -class MetricState: - metric: MetricType - value: float | int - - -class WorkflowNode(Enum): - UPDATE_MEMORY = "update_memory" - GENERATE_RESPONSE = "generate_response" - - -class UpdateState(Enum): - CONTINUE_UPDATE = "continue_update" - FINISH_UPDATE = "finish_update" - - diff --git a/src/summarize_algorithms/core/response_generator.py b/src/summarize_algorithms/core/response_generator.py deleted file mode 100644 index 4399dcb..0000000 --- a/src/summarize_algorithms/core/response_generator.py +++ /dev/null @@ -1,30 +0,0 @@ -from langchain_core.language_models import BaseChatModel -from langchain_core.output_parsers import StrOutputParser -from langchain_core.prompts import PromptTemplate -from langchain_core.runnables import Runnable - - -class ResponseGenerator: - def __init__(self, llm: BaseChatModel, prompt_template: PromptTemplate) -> None: - self.llm = llm - self.prompt_template = prompt_template - self.chain = self._build_chain() - - def _build_chain(self) -> Runnable: - return self.prompt_template | self.llm | StrOutputParser() - - def generate_response( - self, dialogue_memory: str, code_memory: str, tool_memory: str, query: str - ) -> str: - try: - response = self.chain.invoke( - { - "dialogue_memory": dialogue_memory, - "code_memory": code_memory, - "tool_memory": tool_memory, - "query": query, - } - ) - return response - except Exception as e: - raise ConnectionError(f"API request failed: {str(e)}") from e diff --git a/src/summarize_algorithms/memory_bank/dialogue_system.py b/src/summarize_algorithms/memory_bank/dialogue_system.py deleted file mode 100644 index 683dd5c..0000000 --- a/src/summarize_algorithms/memory_bank/dialogue_system.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Type - -from src.summarize_algorithms.core.base_dialogue_system import BaseDialogueSystem -from src.summarize_algorithms.core.memory_storage import MemoryStorage -from src.summarize_algorithms.core.models import MemoryBankDialogueState, Session -from src.summarize_algorithms.memory_bank.prompts import SESSION_SUMMARY_PROMPT -from src.summarize_algorithms.memory_bank.summarizer import SessionSummarizer - - -class MemoryBankDialogueSystem(BaseDialogueSystem): - def _build_summarizer(self) -> SessionSummarizer: - return SessionSummarizer(self.llm, SESSION_SUMMARY_PROMPT) - - def _get_initial_state( - self, sessions: list[Session], query: str - ) -> MemoryBankDialogueState: - return MemoryBankDialogueState( - dialogue_sessions=sessions, - code_memory_storage=( - MemoryStorage( - embeddings=self.embed_model, max_session_id=self.max_session_id - ) - if self.embed_code - else None - ), - tool_memory_storage=( - MemoryStorage( - embeddings=self.embed_model, max_session_id=self.max_session_id - ) - if self.embed_tool - else None - ), - query=query, - text_memory_storage=MemoryStorage( - embeddings=self.embed_model, max_session_id=self.max_session_id - ), - ) - - @property - def _get_dialogue_state_class(self) -> Type: - return MemoryBankDialogueState diff --git a/src/summarize_algorithms/recsum/dialogue_system.py b/src/summarize_algorithms/recsum/dialogue_system.py deleted file mode 100644 index 4a81722..0000000 --- a/src/summarize_algorithms/recsum/dialogue_system.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import 
Type - -from src.summarize_algorithms.core.base_dialogue_system import BaseDialogueSystem -from src.summarize_algorithms.core.memory_storage import MemoryStorage -from src.summarize_algorithms.core.models import RecsumDialogueState, Session -from src.summarize_algorithms.recsum.prompts import MEMORY_UPDATE_PROMPT_TEMPLATE -from src.summarize_algorithms.recsum.summarizer import RecursiveSummarizer - - -class RecsumDialogueSystem(BaseDialogueSystem): - def _build_summarizer(self) -> RecursiveSummarizer: - return RecursiveSummarizer(self.llm, MEMORY_UPDATE_PROMPT_TEMPLATE) - - def _get_initial_state( - self, sessions: list[Session], query: str - ) -> RecsumDialogueState: - return RecsumDialogueState( - dialogue_sessions=sessions, - code_memory_storage=MemoryStorage( - embeddings=self.embed_model, max_session_id=self.max_session_id - ), - tool_memory_storage=MemoryStorage( - embeddings=self.embed_model, max_session_id=self.max_session_id - ), - query=query, - ) - - @property - def _get_dialogue_state_class(self) -> Type: - return RecsumDialogueState diff --git a/src/utils/configure_logs.py b/src/utils/configure_logs.py new file mode 100644 index 0000000..722858a --- /dev/null +++ b/src/utils/configure_logs.py @@ -0,0 +1,54 @@ +import json +import logging +import os +import traceback + +from logging import LogRecord +from logging.handlers import RotatingFileHandler + +from colorlog import ColoredFormatter + + +class DadaJsonFormatter(logging.Formatter): + def format(self, record: LogRecord) -> str: + log_record = { + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + "timestamp": str(int(record.created)) + } + + if record.exc_info: + log_record["exception"] = "".join(traceback.format_exception(*record.exc_info)) + + return json.dumps(log_record) + + +def configure_logs(logdir: str | None = None, loglevel: int = logging.INFO, log_file: str | None = None) -> None: + logger = logging.getLogger() + logger.setLevel(loglevel) + + if logdir and log_file: + log_file = os.path.join(logdir, log_file) + os.makedirs(logdir, exist_ok=True) + file_handler = RotatingFileHandler(log_file, maxBytes=10 * 1024 * 1024, backupCount=5) + file_handler.setFormatter(DadaJsonFormatter()) + file_handler.setLevel(loglevel) + logger.addHandler(file_handler) + + console_handler = logging.StreamHandler() + console_formatter = ColoredFormatter( + "%(log_color)s%(asctime)s - %(levelname)-8s%(reset)s %(name)s - %(blue)s%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + reset=True, + log_colors={ + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red,bg_white", + } + ) + console_handler.setFormatter(console_formatter) + console_handler.setLevel(loglevel) + logger.addHandler(console_handler) diff --git a/src/utils/parse.py b/src/utils/parse.py new file mode 100644 index 0000000..2615b8b --- /dev/null +++ b/src/utils/parse.py @@ -0,0 +1,23 @@ +from pathlib import Path + +from src.algorithms.summarize_algorithms.core.models import ToolCallBlock +from src.benchmark.tool_plan_benchmarking.load_session import Loader +from src.benchmark.tool_plan_benchmarking.run import BASE_DATA_PATH, JSON_FILE_TEMPLATE + +path_data_type_1: Path = Path(BASE_DATA_PATH) / "data_type_1" +for file in path_data_type_1.glob(JSON_FILE_TEMPLATE): + session = Loader.load_session_data_type_1(file) + for message in session.messages: + if isinstance(message, ToolCallBlock): + if message.id == "c584b0f6-ee9b-4ad0-aa90-a337fb92c9b7": + print(file) + print(session.messages.index(message)) + 
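The second scan just below repeats this loop verbatim with `Loader.load_session_data_type_2`. If this script outlives ad-hoc debugging, the shared pattern could be factored out roughly as follows; this is only a sketch, `find_tool_call` is a new hypothetical helper name, and the loader callables are the two already imported in `parse.py`.

```python
from collections.abc import Callable
from pathlib import Path

from src.algorithms.summarize_algorithms.core.models import Session, ToolCallBlock


def find_tool_call(
    data_dir: Path,
    pattern: str,
    load: Callable[[Path], Session],
    tool_call_id: str,
) -> None:
    # Scan every session file in the directory and report where the tool call with this id appears.
    for file in data_dir.glob(pattern):
        session = load(file)
        for index, message in enumerate(session.messages):
            if isinstance(message, ToolCallBlock) and message.id == tool_call_id:
                print(file)
                print(index)
```

Both directory scans then reduce to one call each with the appropriate loader and data directory.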
+path_data_type_2: Path = Path(BASE_DATA_PATH) / "data_type_2" +for file in path_data_type_2.glob(JSON_FILE_TEMPLATE): + session = Loader.load_session_data_type_2(file) + for message in session.messages: + if isinstance(message, ToolCallBlock): + if message.id == "c584b0f6-ee9b-4ad0-aa90-a337fb92c9b7": + print(file) + print(session.messages.index(message)) diff --git a/src/benchmarking/semantic_similarity.py b/src/utils/semantic_similarity.py similarity index 60% rename from src/benchmarking/semantic_similarity.py rename to src/utils/semantic_similarity.py index e70aee7..181b754 100644 --- a/src/benchmarking/semantic_similarity.py +++ b/src/utils/semantic_similarity.py @@ -1,3 +1,5 @@ +import json + from dataclasses import dataclass from typing import Any @@ -27,6 +29,20 @@ def __init__( self.tokenizer = tiktoken.get_encoding("cl100k_base") self.use_tokenizer = use_tokenizer + @staticmethod + def _to_text(value: Any) -> str: + """Convert arbitrary JSON-ish values to text for embedding.""" + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, (int, float, bool)): + return str(value) + try: + return json.dumps(value, ensure_ascii=False, sort_keys=True) + except TypeError: + return str(value) + def _tokenize(self, text: str) -> np.ndarray: if not text or not text.strip(): return np.array([]) @@ -83,3 +99,38 @@ def compute_similarity( return SemanticSimilarityResult( precision=float(precision), recall=float(recall), f1=float(f1) ) + + def calculate(self, sentence_a: str, sentence_b: str) -> float: + """Embed two sentences and return their cosine similarity.""" + + sentence_a = sentence_a.strip() + sentence_b = sentence_b.strip() + if not sentence_a or not sentence_b: + raise ValueError("Sentences must be non-empty.") + + vecs = self.embeddings.embed_documents([sentence_a, sentence_b]) + vec_a = np.asarray(vecs[0], dtype=float).reshape(1, -1) + vec_b = np.asarray(vecs[1], dtype=float).reshape(1, -1) + return float(cosine_similarity(vec_a, vec_b)[0][0]) + + def compare_json(self, json_a: dict[str, Any], json_b: dict[str, Any]) -> float: + """Compare two JSON objects and return the average similarity score. + + Values are coerced to text via `_to_text()` before embedding. + """ + + common_keys = set(json_a.keys()).intersection(set(json_b.keys())) #TODO only common keys?? 
+ if not common_keys: + return 0.0 + + similarities: list[float] = [] + for key in common_keys: + similarity = self.calculate( + self._to_text(json_a[key]), + self._to_text(json_b[key]), + ) + if similarity >= 0.7: + print("yes") + similarities.append(similarity) + + return float(np.mean(similarities)) diff --git a/src/utils/system_prompt_builder.py b/src/utils/system_prompt_builder.py new file mode 100644 index 0000000..cdbb389 --- /dev/null +++ b/src/utils/system_prompt_builder.py @@ -0,0 +1,83 @@ +import json + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from jinja2 import Environment, FileSystemLoader, select_autoescape + + +@dataclass(frozen=True) +class MemorySections: + """Optional memory sections to inject between introduction and tool/schema blocks.""" + + recap: str | None = None + memory_bank: str | None = None + code_knowledge: str | None = None + tool_memory: str | None = None + + +class SystemPromptBuilder: + """Builds a single unified system prompt using repository Jinja2 templates.""" + + def __init__(self) -> None: + templates_dir = Path(__file__).resolve().parents[1] / "prompt_templates" + self._env = Environment( + loader=FileSystemLoader(str(templates_dir)), + autoescape=select_autoescape(disabled_extensions=("j2",)), + trim_blocks=True, + lstrip_blocks=True, + ) + + def build( + self, + *, + schema: dict[str, Any] | None, + tools: list[dict[str, Any]] | None, + memory: MemorySections, + memory_mode: str, + examples: str = "", + ) -> str: + """ + Build the unified system prompt in the required order. + + Order: + 1) introduction.j2 + 2) memory blocks (conditional) + 3) schema_and_tool.j2 + 4) bridge_to_conversation.j2 + + :param schema: JSON schema for model output (structured output). + :param tools: JSON tools for tool-calling. + :param memory: optional memory sections. + :param memory_mode: "baseline" or "memory" (affects MemoryArtifacts description). + :param examples: optional examples block. + :return: str: rendered system prompt. 
+ """ + intro = self._env.get_template("introduction.j2").render().strip() + + memory_text = self._env.get_template("memory_injection.j2").render( + recap=memory.recap, + memory_bank=memory.memory_bank, + code_knowledge=memory.code_knowledge, + tool_memory=memory.tool_memory, + ).strip() + + schema_json = json.dumps(schema or {}, ensure_ascii=False, indent=4) + tools_json = json.dumps(tools or {}, ensure_ascii=False, indent=4) + + schema_and_tool = self._env.get_template("schema_and_tool.j2").render( + tools=tools_json, + schema_json=schema_json, + memory_mode=memory_mode, + examples=examples, + ).strip() + + bridge = self._env.get_template("bridge_to_conversation.j2").render().strip() + + parts = [intro] + if memory_text != "": + parts.append(memory_text) + parts.extend([schema_and_tool, bridge]) + + return "\n\n".join(parts).strip() diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 0aafd08..d1e3489 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -1,40 +1,59 @@ +from datetime import datetime +from pathlib import Path from unittest.mock import MagicMock import pytest -from src.benchmarking.tool_metrics.calculator import Calculator -from src.summarize_algorithms.core.models import ( +from src.algorithms.dialogue import Dialogue +from src.algorithms.summarize_algorithms.core.models import ( BaseBlock, DialogueState, - MetricState, - MetricType, Session, ) +from src.benchmark.models.dtos import BaseRecord, MetricState +from src.benchmark.models.enums import MetricType +from src.benchmark.tool_plan_benchmarking.calculator import Calculator @pytest.fixture def fake_logger(): logger = MagicMock() - logger.log_iteration.return_value = {"logged": True, "system_name": "FakeAlgo"} + + logger.log_iteration.return_value = BaseRecord( + timestamp=datetime.now().isoformat(), + iteration=1, + system="FakeAlgo", + query="What is AI?", + response={"some": "response"}, + sessions=[], + prepared_messages=[], + metric=[MetricState(metric_name=MetricType("COHERENCE"), metric_value=0.95)] + ) return logger @pytest.fixture def fake_evaluator(): evaluator = MagicMock() - evaluator.evaluate.return_value = MetricState(metric=MetricType.COHERENCE, value=0.95) + evaluator.evaluate.return_value = MetricState( + metric_name=MetricType("COHERENCE"), + metric_value=0.95 + ) return evaluator @pytest.fixture def fake_algorithm(): - algo = MagicMock() - algo.__class__.__name__ = "FakeAlgorithm" + algo = MagicMock(spec=Dialogue) + algo.system_name = "FakeAlgorithm" + fake_state = DialogueState( dialogue_sessions=[], code_memory_storage=None, tool_memory_storage=None, query="What is AI?", + _response={"some": "response"}, + prepared_messages=[] ) algo.process_dialogue.return_value = fake_state return algo @@ -53,28 +72,37 @@ def sessions(): def reference_session(): return Session([ BaseBlock(role="USER", content="Tell me about AI."), - BaseBlock(role="ASSISTANT", content="AI stands for artificial intelligence."), - BaseBlock(role="USER", content="What is AI?") + BaseBlock(role="ASSISTANT", content="AI stands for artificial intelligence.") ]) def test_evaluate_success(fake_logger, fake_evaluator, fake_algorithm, sessions, reference_session): - calc = Calculator(logger=fake_logger) - - results = calc.evaluate( + results = Calculator.evaluate( algorithms=[fake_algorithm], - evaluator_function=fake_evaluator, + evaluator_functions=[fake_evaluator], sessions=sessions[:1], - reference_session=reference_session + reference=reference_session.messages, + logger=fake_logger, + prompt="What 
is AI?", + subdirectory=Path("test_subdir"), + iteration=1 ) assert isinstance(results, list) assert len(results) == 1 - assert results[0]["logged"] + assert isinstance(results[0], BaseRecord) + + assert results[0].system == "FakeAlgo" + assert results[0].iteration == 1 + assert results[0].query == "What is AI?" + assert isinstance(results[0].metric, list) + assert len(results[0].metric) > 0 + assert results[0].metric[0].metric_value == 0.95 fake_algorithm.process_dialogue.assert_called_once() fake_evaluator.evaluate.assert_called_once() - fake_logger.log_iteration.assert_called_once() - query_passed = fake_algorithm.process_dialogue.call_args[0][1] - assert query_passed == "What is AI?" + logger_call_args = fake_logger.log_iteration.call_args + + assert logger_call_args[0][0] == "FakeAlgorithm" + assert logger_call_args[0][6][0].metric_value == 0.95 diff --git a/tests/test_dialogue_baseline_prompt_transfer.py b/tests/test_dialogue_baseline_prompt_transfer.py new file mode 100644 index 0000000..ea27021 --- /dev/null +++ b/tests/test_dialogue_baseline_prompt_transfer.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage + +from src.algorithms.simple_algorithms.dialogue_baseline import DialogueBaseline +from src.algorithms.summarize_algorithms.core.models import BaseBlock, Session +from src.algorithms.summarize_algorithms.core.response_generator import ( + ResponseGenerator, +) +from src.benchmark.tool_plan_benchmarking.tools_and_schemas.parsed_jsons import PLAN_SCHEMA, TOOLS +from src.utils.system_prompt_builder import MemorySections, SystemPromptBuilder + + +@dataclass +class InvocationCapture: + messages: list[BaseMessage] | None = None + + +class CapturingChain: + def __init__(self, capture: InvocationCapture, response: Any = "ok") -> None: + self._capture = capture + self._response = response + + def invoke(self, messages: list[BaseMessage]) -> Any: + self._capture.messages = messages + return self._response + + +class DummyCallback: + prompt_tokens = 0 + completion_tokens = 0 + total_cost = 0.0 + + def __enter__(self) -> DummyCallback: + return self + + def __exit__(self, exc_type, exc, tb) -> None: # noqa: ANN001 + return None + + +@pytest.fixture +def fake_llm() -> MagicMock: + llm = MagicMock(spec=BaseChatModel) + llm.get_num_tokens_from_messages.side_effect = lambda msgs: len(msgs) * 10 + return llm + + +def test_dialogue_baseline_transfers_full_prompt_as_message_list(monkeypatch: pytest.MonkeyPatch, fake_llm: MagicMock): + """Baseline must pass a list[BaseMessage] (not dict vars) and use the unified system templates.""" + + # Patch LangChain callback context used in DialogueBaseline + import src.algorithms.simple_algorithms.dialogue_baseline as baseline_mod + + monkeypatch.setattr(baseline_mod, "get_openai_callback", lambda: DummyCallback()) + + baseline = DialogueBaseline.__new__(DialogueBaseline) + baseline.system_name = "Baseline" + baseline.llm = fake_llm + baseline._prompt_builder = SystemPromptBuilder() + baseline.prompt_tokens = 0 + baseline.completion_tokens = 0 + baseline.total_cost = 0.0 + + capture = InvocationCapture() + baseline._build_chain = lambda *_args, **_kwargs: CapturingChain(capture) # type: ignore[method-assign] + + sessions = [Session([BaseBlock(role="USER", content="Hi")])] + baseline.process_dialogue( + 
sessions=sessions, + system_prompt="User question", + structure=PLAN_SCHEMA, + tools=None, + ) + + assert capture.messages is not None + assert isinstance(capture.messages, list) + assert isinstance(capture.messages[0], SystemMessage) + assert isinstance(capture.messages[-1], HumanMessage) + assert capture.messages[-1].content == "Hi" or capture.messages[-1].content != "" # sanity + + expected_system = baseline._prompt_builder.build( + schema=PLAN_SCHEMA, + tools=TOOLS, + memory=MemorySections(), + memory_mode="baseline", + ) + assert capture.messages[0].content == expected_system + + # Cross-check with ResponseGenerator system prompt for empty memory. + rg = ResponseGenerator(fake_llm, structure=PLAN_SCHEMA, max_prompt_tokens=None) + rg_system = rg._build_system_message(memory=MemorySections()).content + assert rg_system == expected_system diff --git a/tests/test_f1_tool_evaluator.py b/tests/test_f1_tool_evaluator.py new file mode 100644 index 0000000..df539f4 --- /dev/null +++ b/tests/test_f1_tool_evaluator.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import json + +from decimal import Decimal + +import pytest + +from src.algorithms.summarize_algorithms.core.models import ( + BaseBlock, + DialogueState, + Session, + ToolCallBlock, +) +from src.benchmark.models.enums import MetricType +from src.benchmark.tool_plan_benchmarking.evaluators.f1_tool_evaluator import ( + F1ToolEvaluator, +) + + +def _state_with_plan(tool_calls: list[tuple[str, dict]]) -> DialogueState: + state = DialogueState( + dialogue_sessions=[], + prepared_messages=[], + code_memory_storage=None, + tool_memory_storage=None, + query="q", + ) + state._response = { + "plan_steps": [ + {"kind": "tool_call", "name": name, "args": args, "id": f"s{i}", "description": "", "depends_on": []} + for i, (name, args) in enumerate(tool_calls, start=1) + ] + } + return state + + +def _reference(*names_and_args: tuple[str, dict]) -> list[BaseBlock]: + blocks: list[BaseBlock] = [] + for i, (name, args) in enumerate(names_and_args, start=1): + blocks.append( + ToolCallBlock( + role="ASSISTANT", + id=f"t{i}", + name=name, + arguments=json.dumps(args), + response="", + content="", + ) + ) + return blocks + + +def test_f1_simple_counts_tool_names() -> None: + evaluator = F1ToolEvaluator(mode="simple") + state = _state_with_plan([("list_dir", {}), ("read_file", {})]) + ref = _reference(("read_file", {"a": 1}), ("search_for_text", {})) + + metric = evaluator.evaluate([Session([])], "q", state, ref) + + assert metric.metric_name == MetricType.F1_TOOL + # predicted={list_dir,read_file}, reference={read_file,search_for_text} => tp=1 fp=1 fn=1 => f1=0.5 + assert metric.metric_value == Decimal("0.5") + + +def test_f1_strict_requires_args_match() -> None: + evaluator = F1ToolEvaluator(mode="strict") + state = _state_with_plan([("read_file", {"path": "a"})]) + ref = _reference(("read_file", {"path": "b"})) + + metric = evaluator.evaluate([Session([])], "q", state, ref) + + assert metric.metric_name == MetricType.F1_TOOL_STRICT + assert metric.metric_value == Decimal("0") diff --git a/tests/test_generate_response_node_prompt.py b/tests/test_generate_response_node_prompt.py new file mode 100644 index 0000000..ceeb90b --- /dev/null +++ b/tests/test_generate_response_node_prompt.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage, 
HumanMessage, SystemMessage + +from src.algorithms.summarize_algorithms.core.graph_nodes import generate_response_node +from src.algorithms.summarize_algorithms.core.models import ( + BaseBlock, + RecsumDialogueState, + Session, +) +from src.algorithms.summarize_algorithms.core.response_generator import ( + ResponseGenerator, +) +from src.benchmark.tool_plan_benchmarking.tools_and_schemas.parsed_jsons import PLAN_SCHEMA + + +@dataclass +class InvocationCapture: + messages: list[BaseMessage] | None = None + + +class CapturingChain: + def __init__(self, capture: InvocationCapture, response: Any = "ok") -> None: + self._capture = capture + self._response = response + + def invoke(self, messages: list[BaseMessage]) -> Any: + self._capture.messages = messages + return self._response + + +def test_generate_response_node_invokes_llm_with_full_history_and_memory() -> None: + fake_llm = MagicMock(spec=BaseChatModel) + fake_llm.get_num_tokens_from_messages.return_value = 10 + + capture = InvocationCapture() + rg = ResponseGenerator(fake_llm, structure=PLAN_SCHEMA, max_prompt_tokens=None) + rg._chain = CapturingChain(capture, response={"plan_steps": []}) # type: ignore[assignment] + + state = RecsumDialogueState( + dialogue_sessions=[], + prepared_messages=[], + code_memory_storage=None, + tool_memory_storage=None, + query="What now?", + last_session=Session([BaseBlock(role="USER", content="Hi")]), + text_memory=[["memory line"]], + ) + + out_state = generate_response_node(rg, state) + + assert capture.messages is not None + assert out_state.prepared_messages == capture.messages + assert out_state.response == {"plan_steps": []} + + assert isinstance(capture.messages[0], SystemMessage) + assert "### RECAP:" in capture.messages[0].content + assert "memory line" in capture.messages[0].content + + assert isinstance(capture.messages[-1], HumanMessage) + assert capture.messages[-1].content == "What now?" diff --git a/tests/test_llm_as_a_judge_base_evaluator_system_prompt_builder.py b/tests/test_llm_as_a_judge_base_evaluator_system_prompt_builder.py new file mode 100644 index 0000000..fc4e93f --- /dev/null +++ b/tests/test_llm_as_a_judge_base_evaluator_system_prompt_builder.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import HumanMessage, SystemMessage +from pydantic import BaseModel + +from src.benchmark.models.dtos import MetricState +from src.benchmark.models.enums import MetricType +from src.benchmark.tool_plan_benchmarking.evaluators.llm_as_a_judge_base_evaluator import LLMAsAJudgeBaseEvaluator + + +class DummyResult(BaseModel): + score: int + + +class DummyJudgeEvaluator(LLMAsAJudgeBaseEvaluator): + def _build_single_user_prompt(self, params: dict[str, Any]) -> str: + return f"SINGLE: {params['x']}" + + def _build_pairwise_user_prompt(self, params: dict[str, Any]) -> str: + return f"PAIRWISE: {params['x']}" + + def _get_single_result_model(self) -> type[BaseModel]: + return DummyResult + + def _get_pairwise_result_model(self) -> type[BaseModel]: + return DummyResult + + def evaluate(self, sessions, query, state, reference=None) -> MetricState: # noqa: ANN001 + # Not needed for this unit test. 
+ return MetricState(metric_name=MetricType("COHERENCE"), metric_value=True) + + +def test_llm_as_a_judge_uses_system_prompt_builder_and_message_list() -> None: + llm = MagicMock(spec=BaseChatModel) + + chain = MagicMock() + chain.invoke.return_value = DummyResult(score=1) + llm.with_structured_output.return_value = chain + + evaluator = DummyJudgeEvaluator(llm=llm) + + result = evaluator._invoke_single({"x": "hello"}) + assert isinstance(result, DummyResult) + assert result.score == 1 + + # Ensure we invoked the LLM with a list of messages. + args, _kwargs = chain.invoke.call_args + messages = args[0] + + assert isinstance(messages[0], SystemMessage) + assert "Role" in messages[0].content # comes from `introduction.j2` + assert isinstance(messages[1], HumanMessage) + assert messages[1].content == "SINGLE: hello" diff --git a/tests/test_memory_bank_dialogue_system_offline.py b/tests/test_memory_bank_dialogue_system_offline.py new file mode 100644 index 0000000..e17b94a --- /dev/null +++ b/tests/test_memory_bank_dialogue_system_offline.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from langchain_core.embeddings import Embeddings +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage + +from src.algorithms.summarize_algorithms.core.models import BaseBlock, Session +from src.algorithms.summarize_algorithms.core.response_generator import ( + ResponseGenerator, +) +from src.algorithms.summarize_algorithms.memory_bank.dialogue_system import ( + MemoryBankDialogueSystem, +) +from src.benchmark.tool_plan_benchmarking.tools_and_schemas.parsed_jsons import PLAN_SCHEMA + + +class FakeEmbeddings(Embeddings): + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self._embed(t) for t in texts] + + def embed_query(self, text: str) -> list[float]: + return self._embed(text) + + @staticmethod + def _embed(text: str) -> list[float]: + t = text.lower() + return [1.0 if "kafka" in t else 0.0, 1.0 if "redis" in t else 0.0, 0.0] + + +@dataclass +class Capture: + invoked_messages: list[BaseMessage] | None = None + + +class CapturingChain: + def __init__(self, capture: Capture, response: Any) -> None: + self._capture = capture + self._response = response + + def invoke(self, messages: list[BaseMessage]) -> Any: + self._capture.invoked_messages = messages + return self._response + + +class FakeSessionSummarizer: + def summarize(self, session_messages: str, session_id: int): # noqa: ANN001 + _ = session_id + return [BaseBlock(role="SYSTEM", content=f"MEM<{session_messages}>")] + + +def test_memory_bank_pipeline_retrieves_and_injects_memory(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OPENAI_API_KEY", "offline") + + capture = Capture() + + from src.algorithms.summarize_algorithms.core.base_dialogue_system import ( + BaseDialogueSystem, + ) + + fake_llm = MagicMock(spec=BaseChatModel) + fake_llm.get_num_tokens_from_messages.return_value = 10 + + def _init(self, llm=None, is_local=False): # noqa: ANN001 + self.llm = fake_llm + self.memory_llm = fake_llm + + monkeypatch.setattr(BaseDialogueSystem, "_initialize_model", _init, raising=True) + + monkeypatch.setattr( + MemoryBankDialogueSystem, + "_build_summarizer", + lambda self: FakeSessionSummarizer(), + raising=True, + ) + + def _build_chain(self: ResponseGenerator): + return CapturingChain(capture, 
response={"plan_steps": []}) + + monkeypatch.setattr(ResponseGenerator, "_build_chain", _build_chain, raising=True) + + system = MemoryBankDialogueSystem(embed_model=FakeEmbeddings(), embed_code=False, embed_tool=False) + + # One session that will be summarized and stored. + sessions = [Session([BaseBlock(role="USER", content="Kafka is used here")])] + + # Run pipeline. + state = system.process_dialogue( + sessions=sessions, + system_prompt="Where is Kafka used?", + structure=PLAN_SCHEMA, + tools=[], + ) + + assert capture.invoked_messages is not None + assert state.prepared_messages == capture.invoked_messages + + system_msg = capture.invoked_messages[0] + assert isinstance(system_msg, SystemMessage) + assert "### MEMORY BANK:" in system_msg.content + + # Should contain our summarized memory content, since query mentions kafka. + assert "Kafka is used here" in system_msg.content + + assert isinstance(capture.invoked_messages[-1], HumanMessage) + assert capture.invoked_messages[-1].content == "Where is Kafka used?" diff --git a/tests/test_memory_logger.py b/tests/test_memory_logger.py index 4fd7456..079037c 100644 --- a/tests/test_memory_logger.py +++ b/tests/test_memory_logger.py @@ -1,15 +1,17 @@ import json +from pathlib import Path + import pytest -from src.benchmarking.memory_logger import MemoryLogger -from src.summarize_algorithms.core.models import ( +from src.algorithms.summarize_algorithms.core.models import ( BaseBlock, DialogueState, - MetricState, - MetricType, Session, ) +from src.benchmark.logger.memory_logger import MemoryLogger +from src.benchmark.models.dtos import MemoryRecord, MetricState +from src.benchmark.models.enums import MetricType class FakeStorage: @@ -27,6 +29,7 @@ def fake_state(): code_memory_storage=FakeStorage("code"), tool_memory_storage=FakeStorage("tool"), query="Test query", + prepared_messages=[] ) s._response = "Test response" s.text_memory = [["memory line 1", "memory line 2"]] @@ -40,7 +43,10 @@ def sessions(): def test_log_iteration_creates_file(tmp_path, fake_state, sessions): logger = MemoryLogger(logs_dir=tmp_path) - metric = MetricState(metric=MetricType.COHERENCE, value=0.87) + + metric_list = [MetricState(metric_name=MetricType("COHERENCE"), metric_value=0.87)] + + subdir = Path("test_runs") record = logger.log_iteration( system_name="FakeSystem", @@ -48,30 +54,46 @@ def test_log_iteration_creates_file(tmp_path, fake_state, sessions): iteration=1, sessions=sessions, state=fake_state, - metric=metric + metrics=metric_list, + subdirectory=subdir ) - assert isinstance(record, dict) - assert record["system"] == "FakeSystem" - assert record["iteration"] == 1 - assert record["query"] == "Hello?" - assert record["response"] == "Test response" - assert "metric_name" in record and abs(record["metric_value"] - 0.87) < 0.0001 + assert isinstance(record, MemoryRecord) + assert record.system == "FakeSystem" + assert record.iteration == 1 + assert record.query == "Hello?" 
+ assert record.response == "Test response" + + assert record.metric is not None + assert len(record.metric) == 1 - expected_file = tmp_path / "FakeSystem1.jsonl" - assert expected_file.exists() + first_metric = record.metric[0] + if isinstance(first_metric, dict): + assert abs(first_metric["metric_value"] - 0.87) < 0.0001 + else: + assert abs(first_metric.metric_value - 0.87) < 0.0001 - content = expected_file.read_text(encoding="utf-8").strip() + expected_dir = tmp_path / subdir + assert expected_dir.exists() + + files = list(expected_dir.glob("FakeSystem-*.json")) + assert len(files) == 1 + log_file = files[0] + + content = log_file.read_text(encoding="utf-8").strip() parsed = json.loads(content) + assert parsed["system"] == "FakeSystem" assert parsed["query"] == "Hello?" assert "sessions" in parsed assert isinstance(parsed["sessions"], list) assert parsed["sessions"][0]["messages"][0]["content"] == "Hello there!" + assert parsed["metric"][0]["metric_name"] == "COHERENCE" def test_log_iteration_without_metric(tmp_path, fake_state, sessions): logger = MemoryLogger(logs_dir=tmp_path) + subdir = Path("no_metric_runs") record = logger.log_iteration( system_name="SystemNoMetric", @@ -79,11 +101,15 @@ def test_log_iteration_without_metric(tmp_path, fake_state, sessions): iteration=2, sessions=sessions, state=fake_state, - metric=None + metrics=None, + subdirectory=subdir ) - assert "metric_name" not in record - assert "metric_value" not in record + assert record.metric is None or record.metric == [] + + expected_dir = tmp_path / subdir + files = list(expected_dir.glob("SystemNoMetric-*.json")) + assert len(files) == 1 - expected_file = tmp_path / "SystemNoMetric2.jsonl" - assert expected_file.exists() + parsed = json.loads(files[0].read_text(encoding="utf-8")) + assert "metric" not in parsed or parsed["metric"] is None diff --git a/tests/test_memory_storage_offline.py b/tests/test_memory_storage_offline.py new file mode 100644 index 0000000..39a69ef --- /dev/null +++ b/tests/test_memory_storage_offline.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import pytest + +from langchain_core.embeddings import Embeddings + +from src.algorithms.summarize_algorithms.core.memory_storage.memory_storage import ( + MemoryStorage, +) + +# IMPORTANT: import models first to avoid circular import issues in `MemoryStorage`. +from src.algorithms.summarize_algorithms.core.models import BaseBlock + + +class FakeEmbeddings(Embeddings): + """Deterministic offline embeddings. + + Very small vectors where similarity is driven by keyword presence. + """ + + def embed_documents(self, texts: list[str]) -> list[list[float]]: # noqa: D401 + return [self._embed(t) for t in texts] + + def embed_query(self, text: str) -> list[float]: # noqa: D401 + return self._embed(text) + + @staticmethod + def _embed(text: str) -> list[float]: + t = text.lower() + return [ + 1.0 if "alpha" in t else 0.0, + 1.0 if "beta" in t else 0.0, + float(len(t) % 7) / 7.0, + ] + + +def test_memory_storage_add_and_find_similar_offline(monkeypatch: pytest.MonkeyPatch) -> None: + # MemoryStorage requires OPENAI_API_KEY even when custom embeddings are provided. 
+ monkeypatch.setenv("OPENAI_API_KEY", "offline") + + storage = MemoryStorage(embeddings=FakeEmbeddings(), max_session_id=10) + + storage.add_memory( + [ + BaseBlock(role="SYSTEM", content="alpha memory"), + BaseBlock(role="SYSTEM", content="beta memory"), + ], + session_id=0, + ) + + results = storage.find_similar("alpha", top_k=1) + assert len(results) == 1 + assert "alpha" in results[0].content.lower() diff --git a/tests/test_prompt_parity.py b/tests/test_prompt_parity.py new file mode 100644 index 0000000..98d960d --- /dev/null +++ b/tests/test_prompt_parity.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage, SystemMessage + +from src.algorithms.simple_algorithms.dialogue_baseline import DialogueBaseline +from src.algorithms.summarize_algorithms.core.response_generator import ( + ResponseGenerator, +) +from src.benchmark.tool_plan_benchmarking.tools_and_schemas.parsed_jsons import PLAN_SCHEMA, TOOLS +from src.utils.system_prompt_builder import MemorySections + + +@dataclass +class InvocationCapture: + messages: list[BaseMessage] | None = None + + +class CapturingChain: + def __init__(self, capture: InvocationCapture, response: Any = "ok") -> None: + self._capture = capture + self._response = response + + def invoke(self, messages: list[BaseMessage]) -> Any: + self._capture.messages = messages + return self._response + + +@pytest.fixture +def fake_llm() -> MagicMock: + llm = MagicMock(spec=BaseChatModel) + # token counter is used by trimming/cropping + llm.get_num_tokens_from_messages.return_value = 10 + return llm + + +def _make_simple_baseline(fake_llm: MagicMock) -> DialogueBaseline: + # Avoid real model init + baseline = DialogueBaseline.__new__(DialogueBaseline) + baseline.system_name = "Baseline" + baseline.llm = fake_llm + + from src.utils.system_prompt_builder import SystemPromptBuilder + + baseline._prompt_builder = SystemPromptBuilder() + baseline.prompt_tokens = 0 + baseline.completion_tokens = 0 + baseline.total_cost = 0.0 + return baseline + + +def test_response_generator_system_prompt_matches_baseline(fake_llm: MagicMock) -> None: + """Baseline and ResponseGenerator must use the same unified system templates. + + This is the core contract: system prompt content must be identical when schema/memory are equivalent. + """ + # Baseline builds its system prompt with memory_mode="baseline" and schema + baseline = _make_simple_baseline(fake_llm) + baseline_system = baseline._prompt_builder.build( + schema=PLAN_SCHEMA, + tools=TOOLS, + memory=MemorySections(), + memory_mode="baseline", + ) + + # ResponseGenerator must infer baseline mode when memory is empty + rg = ResponseGenerator(fake_llm, structure=PLAN_SCHEMA, max_prompt_tokens=None) + system_message = rg._build_system_message(memory=MemorySections()) + + assert baseline_system == system_message.content + + +def test_baseline_and_response_generator_invoke_with_same_history(fake_llm: MagicMock) -> None: + """Both paths must pass `list[BaseMessage]` to LLM with the same ordering. 
+ + We verify: + - first message is SystemMessage + - system content is identical + - the rest is the history (with the latest user query last) + """ + # Capture ResponseGenerator invocation + rg_capture = InvocationCapture() + rg = ResponseGenerator(fake_llm, structure=PLAN_SCHEMA, max_prompt_tokens=None) + rg._chain = CapturingChain(rg_capture) # type: ignore[assignment] + + from src.algorithms.summarize_algorithms.core.models import BaseBlock, Session + + session = Session([BaseBlock(role="USER", content="Hi")]) + rg.generate_response( + last_session=session, + user_query="Question", + memory=MemorySections(), + ) + + assert rg_capture.messages is not None + assert isinstance(rg_capture.messages[0], SystemMessage) + assert rg_capture.messages[-1].content == "Question" + + +def test_memory_mode_inference_adds_memory_blocks(fake_llm: MagicMock) -> None: + """When memory sections are present, they must appear in system prompt.""" + rg = ResponseGenerator(fake_llm, structure=PLAN_SCHEMA, max_prompt_tokens=None) + + mem = MemorySections(recap="recap text") + system_message = rg._build_system_message(memory=mem) + + assert "### RECAP:" in system_message.content + assert "recap text" in system_message.content diff --git a/tests/test_recsum_dialogue_system_offline.py b/tests/test_recsum_dialogue_system_offline.py new file mode 100644 index 0000000..7baecb9 --- /dev/null +++ b/tests/test_recsum_dialogue_system_offline.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import pytest + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage + +from src.algorithms.summarize_algorithms.core.models import BaseBlock, Session +from src.algorithms.summarize_algorithms.core.response_generator import ( + ResponseGenerator, +) +from src.algorithms.summarize_algorithms.recsum.dialogue_system import ( + RecsumDialogueSystem, +) +from src.benchmark.tool_plan_benchmarking.tools_and_schemas.parsed_jsons import PLAN_SCHEMA + + +@dataclass +class Capture: + invoked_messages: list[BaseMessage] | None = None + + +class CapturingChain: + def __init__(self, capture: Capture, response: Any) -> None: + self._capture = capture + self._response = response + + def invoke(self, messages: list[BaseMessage]) -> Any: + self._capture.invoked_messages = messages + return self._response + + +class FakeSummarizer: + def summarize(self, previous_memory: str, dialogue_context: str): # noqa: ANN001 + # Return BaseBlock objects, as expected by update_memory_node. + _ = previous_memory + return [BaseBlock(role="SYSTEM", content=f"RECSUM<{dialogue_context}>")] + + +def test_recsum_pipeline_builds_prompt_with_recap_and_invokes_llm(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OPENAI_API_KEY", "offline") + + capture = Capture() + + # Patch BaseDialogueSystem model initialization to avoid network / API keys. 
+    from src.algorithms.summarize_algorithms.core.base_dialogue_system import (
+        BaseDialogueSystem,
+    )
+
+    # Assign MagicMock LLMs so no real model (and no API key) is ever initialized.
+    import unittest.mock
+
+    fake_llm = unittest.mock.MagicMock(spec=BaseChatModel)
+    fake_llm.get_num_tokens_from_messages.return_value = 10
+
+    def _init(self, llm=None, is_local=False):  # noqa: ANN001
+        self.llm = fake_llm
+        self.memory_llm = fake_llm
+
+    monkeypatch.setattr(BaseDialogueSystem, "_initialize_model", _init, raising=True)
+
+    # Patch summarizer builder
+    monkeypatch.setattr(RecsumDialogueSystem, "_build_summarizer", lambda self: FakeSummarizer(), raising=True)
+
+    # Patch ResponseGenerator to use capturing chain
+    def _build_chain(self: ResponseGenerator):
+        return CapturingChain(capture, response={"plan_steps": []})
+
+    monkeypatch.setattr(ResponseGenerator, "_build_chain", _build_chain, raising=True)
+
+    system = RecsumDialogueSystem(is_local=True)
+
+    sessions = [Session([BaseBlock(role="USER", content="hello")])]
+    state = system.process_dialogue(
+        sessions=sessions,
+        system_prompt="What now?",
+        structure=PLAN_SCHEMA,
+        tools=[],
+    )
+
+    assert capture.invoked_messages is not None
+    assert state.prepared_messages == capture.invoked_messages
+
+    # System message must contain recap
+    assert isinstance(capture.invoked_messages[0], SystemMessage)
+    assert "### RECAP:" in capture.invoked_messages[0].content
+
+    # Latest query must be last human message
+    assert isinstance(capture.invoked_messages[-1], HumanMessage)
+    assert capture.invoked_messages[-1].content == "What now?"
diff --git a/tests/test_response_generator.py b/tests/test_response_generator.py
index c7ba7f9..d5ff3f0 100644
--- a/tests/test_response_generator.py
+++ b/tests/test_response_generator.py
@@ -3,91 +3,100 @@ import pytest
 from langchain_core.language_models import BaseChatModel
-from langchain_core.prompts import PromptTemplate
+from langchain_core.messages import HumanMessage, SystemMessage
 
-from src.summarize_algorithms.core.response_generator import ResponseGenerator
+from src.algorithms.summarize_algorithms.core.models import (
+    BaseBlock,
+    ResponseContext,
+    Session,
+)
+from src.algorithms.summarize_algorithms.core.response_generator import (
+    ResponseGenerator,
+)
+from src.utils.system_prompt_builder import MemorySections
 
 
 @pytest.fixture
 def mock_llm():
-    return create_autospec(BaseChatModel)
+    llm = create_autospec(BaseChatModel)
+    llm.get_num_tokens_from_messages.return_value = 10
+    return llm
 
 
 @pytest.fixture
-def mock_prompt_template():
-    return create_autospec(PromptTemplate)
+def response_generator(mock_llm):
+    return ResponseGenerator(llm=mock_llm)
 
 
 @pytest.fixture
-def response_generator(mock_llm, mock_prompt_template):
-    return ResponseGenerator(llm=mock_llm, prompt_template=mock_prompt_template)
+def empty_session():
+    return Session([])
 
 
-def test_initialization(response_generator, mock_llm, mock_prompt_template):
-    assert response_generator.llm is mock_llm
-    assert response_generator.prompt_template is mock_prompt_template
-    assert hasattr(response_generator, "chain")
+def test_initialization(response_generator, mock_llm):
+    assert response_generator._llm is mock_llm
+    assert hasattr(response_generator, "_chain")
 
 
-def test_generate_response_success(response_generator):
+def test_generate_response_success(response_generator, empty_session):
     mock_chain = MagicMock()
     mock_chain.invoke.return_value = 
"Test response" - response_generator.chain = mock_chain + response_generator._chain = mock_chain + + last_session = Session([BaseBlock(role="USER", content="Hi")]) result = response_generator.generate_response( - dialogue_memory="Memory content", - code_memory="Code memory", - tool_memory="Tool memory", - query="User question", + last_session=last_session, + user_query="User question", + memory=MemorySections(recap="Some memory"), ) - assert result == "Test response" - mock_chain.invoke.assert_called_once_with( - { - "dialogue_memory": "Memory content", - "code_memory": "Code memory", - "tool_memory": "Tool memory", - "query": "User question", - } - ) + assert isinstance(result, ResponseContext) + assert result.response == "Test response" + assert isinstance(result.prepared_history, list) + mock_chain.invoke.assert_called_once() + history = mock_chain.invoke.call_args[0][0] -@pytest.mark.parametrize( - "dialogue_memory,code_memory,tool_memory,query", - [ - ("", "", "", ""), - ("Short", "Short Code", "Short tool", "Simple query"), - ("Very long memory " * 20, "Short Code", "Medium Tool" * 5, "Complex?" * 10), - ("Memory", "", "", "Query"), - ], -) -def test_argument_combinations( - dialogue_memory, code_memory, tool_memory, query, response_generator -): - mock_chain = MagicMock() - mock_chain.invoke.return_value = "Response" - response_generator.chain = mock_chain + assert isinstance(history, list) - response_generator.generate_response( - dialogue_memory, code_memory, tool_memory, query - ) - mock_chain.invoke.assert_called_once_with( - { - "dialogue_memory": dialogue_memory, - "code_memory": code_memory, - "tool_memory": tool_memory, - "query": query, - } - ) + assert isinstance(history[0], SystemMessage) + assert "The System Instruction ends here" in str(history[0].content) + + assert isinstance(history[-1], HumanMessage) + assert history[-1].content == "User question" -def test_generate_response_exception(response_generator): +def test_generate_response_exception(response_generator, empty_session): mock_chain = MagicMock() mock_chain.invoke.side_effect = Exception("Network error") - response_generator.chain = mock_chain + response_generator._chain = mock_chain with pytest.raises(ConnectionError) as exc_info: - response_generator.generate_response("dmem", "cmem", "tmem", "q") + response_generator.generate_response( + last_session=empty_session, + user_query="q", + memory=MemorySections(), + ) assert "API request failed: Network error" in str(exc_info.value) assert isinstance(exc_info.value.__cause__, Exception) + + +def test_history_structure(response_generator): + mock_chain = MagicMock() + mock_chain.invoke.return_value = "resp" + response_generator._chain = mock_chain + + last_ses = Session([BaseBlock(role="ASSISTANT", content="Prev answer")]) + + result = response_generator.generate_response( + last_session=last_ses, + user_query="New query", + memory=MemorySections(recap="Memory info"), + ) + + history = result.prepared_history + + assert "Memory info" in str(history[0].content) + assert history[-1].content == "New query" diff --git a/tests/test_statistics_pipeline_parity.py b/tests/test_statistics_pipeline_parity.py new file mode 100644 index 0000000..59ddeb4 --- /dev/null +++ b/tests/test_statistics_pipeline_parity.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from datetime import datetime +from decimal import Decimal + +from src.benchmark.models.dtos import BaseRecord, MetricState +from src.benchmark.models.enums import MetricType +from 
src.benchmark.tool_plan_benchmarking.statistics.aggregator import ( + MetricStatisticsAggregator, +) +from src.benchmark.tool_plan_benchmarking.statistics.observations_collector import ( + MetricObservationsCollector, +) +from src.benchmark.tool_plan_benchmarking.statistics.statistics import Statistics + + +def _make_record(system: str, sessions_count: int, f1: float, strict: float) -> BaseRecord: + sessions = [{"messages": []} for _ in range(sessions_count)] + return BaseRecord( + timestamp=datetime.now().isoformat(), + iteration=1, + system=system, + query="q", + response={}, + sessions=sessions, + prepared_messages=[], + metric=[ + MetricState(metric_name=MetricType.F1_TOOL, metric_value=Decimal(str(f1))), + MetricState(metric_name=MetricType.F1_TOOL_STRICT, metric_value=Decimal(str(strict))), + ], + ) + + +def test_statistics_facade_matches_new_pipeline() -> None: + records = [ + _make_record("A", 3, 0.1, 0.2), + _make_record("A", 3, 0.3, 0.4), + _make_record("B", 3, 0.5, 0.6), + ] + + # New pipeline + obs = MetricObservationsCollector.collect(records) + stats_new = MetricStatisticsAggregator(normalize=False).aggregate(obs) + + # Old facade API + stats_facade = Statistics.calculate_by_logs( + count_of_launches=999, + metrics=records, + normalize=False, + ) + + assert stats_new.algorithms == stats_facade.algorithms + + +def test_normalization_is_stable_per_metric_type() -> None: + records = [ + _make_record("A", 3, 0.0, 1.0), + _make_record("B", 3, 1.0, 0.0), + ] + + obs = MetricObservationsCollector.collect(records) + stats_norm = MetricStatisticsAggregator(normalize=True).aggregate(obs) + + # For each metric type we should get values normalized to [0, 1] + for alg in stats_norm.algorithms: + for run in alg.runs: + assert 0.0 <= run.value <= 1.0 diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py index 7cc5438..1b123a1 100644 --- a/tests/test_summarizer.py +++ b/tests/test_summarizer.py @@ -3,7 +3,7 @@ import pytest -from src.summarize_algorithms.recsum.summarizer import RecursiveSummarizer +from src.algorithms.summarize_algorithms.recsum.summarizer import RecursiveSummarizer @dataclass diff --git a/tests/test_tool_plan_benchmarking_pipeline.py b/tests/test_tool_plan_benchmarking_pipeline.py new file mode 100644 index 0000000..2a8c961 --- /dev/null +++ b/tests/test_tool_plan_benchmarking_pipeline.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock + +from src.benchmark.logger.baseline_logger import BaselineLogger +from src.benchmark.models.dtos import BaseRecord +from src.benchmark.tool_plan_benchmarking.calculator import Calculator +from src.benchmark.tool_plan_benchmarking.evaluators.f1_tool_evaluator import ( + F1ToolEvaluator, +) + + +class FakeAlgo: + system_name = "FakeAlgorithm" + + def process_dialogue(self, sessions, system_prompt, structure=None, tools=None): # noqa: ANN001 + # Return a state-like object with the minimal surface required by evaluator+logger. 
+ state = MagicMock() + state.response = { + "plan_steps": [ + {"kind": "tool_call", "name": "read_file", "args": {}, "id": "s1", "description": "", "depends_on": []} + ] + } + state.prepared_messages = [] + return state + + +def test_calculator_end_to_end_logs_metrics(tmp_path: Path) -> None: + algo = FakeAlgo() + evaluator_simple = F1ToolEvaluator(mode="simple") + logger = BaselineLogger(logs_dir=tmp_path) + + sessions = [] + reference = [] + + records = Calculator.evaluate( + algorithms=[algo], + evaluator_functions=[evaluator_simple], + sessions=sessions, + prompt="q", + reference=reference, + logger=logger, + subdirectory=Path("sub"), + tools=[], + iteration=1, + ) + + assert len(records) == 1 + assert isinstance(records[0], BaseRecord) + assert records[0].system == "FakeAlgorithm" + + # logger should have written a json file + # Note: Calculator passes `algorithm.system_name / subdirectory` to the logger. + written = list((tmp_path / "FakeAlgorithm" / "sub").glob("FakeAlgorithm-*.json")) + assert len(written) == 1 + + # Now append a *new* metric into the same JSON file. + evaluator_strict = F1ToolEvaluator(mode="strict") + + updated = Calculator.evaluate_by_logs( + algorithms=[algo], + evaluator_functions=[evaluator_strict], + reference=reference, + logger=logger, + logs_path=tmp_path, + subdirectory=Path("sub"), + iteration=1, + ) + + assert len(updated) == 1 + assert updated[0].metric is not None + metric_names = {m.metric_name.value for m in updated[0].metric} + assert "F1_TOOL_STRICT" in metric_names
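
Note: the new offline tests above share one pattern — swap the real LLM chain for a stub that records whatever messages it is invoked with, then assert on message order and system-prompt content. The sketch below is a minimal, self-contained illustration of that pattern; `EchoChain` and the test name are illustrative only (not part of this repository) and assume nothing beyond `langchain_core`, which is already a project dependency.

```python
# Minimal sketch of the capturing-stub pattern used by the new offline tests.
# `EchoChain` is hypothetical; the repo's tests use CapturingChain together
# with their own ResponseGenerator / dialogue systems.
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage


class EchoChain:
    """Records the messages it is invoked with and returns a canned response."""

    def __init__(self) -> None:
        self.seen: list[BaseMessage] | None = None

    def invoke(self, messages: list[BaseMessage]) -> str:
        self.seen = messages
        return "ok"


def test_echo_chain_records_prompt_order() -> None:
    chain = EchoChain()
    chain.invoke([SystemMessage(content="system"), HumanMessage(content="query")])

    # The stub lets the test assert on ordering without any network calls.
    assert chain.seen is not None
    assert isinstance(chain.seen[0], SystemMessage)
    assert chain.seen[-1].content == "query"
```

If saved under `tests/` (file name is up to you), it runs with the rest of the suite, e.g. `pytest -k echo_chain`.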