From 665bcbb7d66c0fa8af0a6c54db63f63f9ceb1c4d Mon Sep 17 00:00:00 2001 From: Steve Wall Date: Mon, 6 Oct 2025 16:12:41 -0600 Subject: [PATCH 1/2] Add LLM fallback capability. --- .env.example | 15 + .gitignore | 10 +- AGENTS.md | 15 + Makefile | 4 + README.md | 18 + activities/tool_activities.py | 67 ++- dev-tools/README.md | 3 + dev-tools/allow-anthropic.sh | 124 ++++++ dev-tools/block-anthropic.sh | 130 ++++++ docs/adding-goals-and-tools.md | 30 +- docs/architecture-decisions.md | 11 + docs/architecture.md | 12 + docs/contributing.md | 7 + docs/setup.md | 58 +++ docs/testing.md | 20 + frontend/package-lock.json | 6 +- goals/travel.py | 2 +- scripts/run_worker.py | 14 +- shared/llm_manager.py | 291 +++++++++++++ tests/README.md | 40 +- tests/test_agent_goal_workflow.py | 56 ++- ...test_agent_goal_workflow_execute_prompt.py | 175 ++++++++ ...est_agent_goal_workflow_validate_prompt.py | 167 ++++++++ tests/test_llm_manager.py | 381 ++++++++++++++++++ tests/test_mcp_integration.py | 24 +- tests/test_tool_activities.py | 96 +++-- .../workflowtests/agent_goal_workflow_test.py | 15 +- workflows/agent_goal_workflow.py | 108 +++-- workflows/workflow_helpers.py | 2 +- 29 files changed, 1757 insertions(+), 144 deletions(-) create mode 100644 dev-tools/README.md create mode 100755 dev-tools/allow-anthropic.sh create mode 100755 dev-tools/block-anthropic.sh create mode 100644 shared/llm_manager.py create mode 100644 tests/test_agent_goal_workflow_execute_prompt.py create mode 100644 tests/test_agent_goal_workflow_validate_prompt.py create mode 100644 tests/test_llm_manager.py diff --git a/.env.example b/.env.example index bb7be2a..717f738 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,13 @@ # Example environment configuration ### LLM configuration +# Output LLM interaction to a file. +LLM_DEBUG_OUTPUT=false +# Primary LLM +#LLM_TIMEOUT_SECONDS=10 +#LLM_MODEL=ollama/gemma3 +#LLM_KEY=no-key-needed +#LLM_BASE_URL=http://localhost:11434 LLM_MODEL=openai/gpt-4o LLM_KEY=sk-proj-... # LLM_MODEL=anthropic/claude-3-5-sonnet-20240620 @@ -8,6 +15,14 @@ LLM_KEY=sk-proj-... # LLM_MODEL=gemini/gemini-2.5-flash-preview-04-17 # LLM_KEY=${GOOGLE_API_KEY} +# Fallback LLM +#LLM_FALLBACK_TIMEOUT_SECONDS=10 +#LLM_FALLBACK_MODEL=ollama/gemma3 +#LLM_FALLBACK_KEY=no-key-needed +#LLM_FALLBACK_BASE_URL=http://localhost:11434 +LLM_FALLBACK_MODEL=openai/gpt-4o +LLM_FALLBACK_KEY=sk-proj-... + ### Tool API keys # RAPIDAPI_KEY=9df2cb5... 
# Optional - if unset flight search generates realistic mock data # RAPIDAPI_HOST_FLIGHTS=sky-scrapper.p.rapidapi.com # For real travel flight information (optional) diff --git a/.gitignore b/.gitignore index 10fc80c..456c094 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # OS-specific files .DS_Store +# Debug files +debug_llm_calls + # Python cache & compiled files __pycache__/ *.py[cod] @@ -30,9 +33,14 @@ coverage.xml # PyCharm / IntelliJ settings .idea/ +*.iml .env .env* # Cursor -.cursor \ No newline at end of file +.cursor + +# Claude Code +CLAUDE.md +.claude diff --git a/AGENTS.md b/AGENTS.md index a993fa3..02bdc46 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,5 +1,20 @@ # Temporal AI Agent Contribution Guide +## Table of Contents +- [Repository Layout](#repository-layout) +- [Running the Application](#running-the-application) + - [Quick Start with Docker](#quick-start-with-docker) + - [Local Development Setup](#local-development-setup) + - [Environment Configuration](#environment-configuration) +- [Testing](#testing) +- [Linting and Code Quality](#linting-and-code-quality) +- [Agent Customization](#agent-customization) + - [Adding New Goals and Tools](#adding-new-goals-and-tools) + - [Configuring Goals](#configuring-goals) +- [Architecture](#architecture) +- [Commit Messages and Pull Requests](#commit-messages-and-pull-requests) +- [Additional Resources](#additional-resources) + ## Repository Layout - `workflows/` - Temporal workflows including the main AgentGoalWorkflow for multi-turn AI conversations - `activities/` - Temporal activities for tool execution and LLM interactions diff --git a/Makefile b/Makefile index 07c5f69..91fae88 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,9 @@ setup: run-worker: uv run scripts/run_worker.py +run-worker-debug: + LOGLEVEL=DEBUG PYTHONUNBUFFERED=1 uv run scripts/run_worker.py + run-api: uv run uvicorn api.main:app --reload @@ -41,6 +44,7 @@ help: @echo "Available commands:" @echo " make setup - Install all dependencies" @echo " make run-worker - Start the Temporal worker" + @echo " make run-worker-debug - Start the Temporal worker with DEBUG logging" @echo " make run-api - Start the API server" @echo " make run-frontend - Start the frontend development server" @echo " make run-train-api - Start the train API server" diff --git a/README.md b/README.md index ca8a3b8..149af03 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,21 @@ # Temporal AI Agent +## Table of Contents +- [Overview](#overview) +- [Demo Videos](#demo-videos) +- [Why Temporal?](#why-temporal) +- [What is "Agentic AI"?](#what-is-agentic-ai) +- [MCP Tool Calling Support](#-mcp-tool-calling-support) +- [Setup and Configuration](#setup-and-configuration) +- [Customizing Interaction & Tools](#customizing-interaction--tools) +- [Architecture](#architecture) +- [Testing](#testing) +- [Development](#development) +- [Productionalization & Adding Features](#productionalization--adding-features) +- [Enablement Guide](#enablement-guide-internal-resource-for-temporal-employees) + +## Overview + This demo shows a multi-turn conversation with an AI agent running inside a Temporal workflow. The purpose of the agent is to collect information towards a goal, running tools along the way. The agent supports both native tools and Model Context Protocol (MCP) tools, allowing it to interact with external services. The agent operates in single-agent mode by default, focusing on one specific goal. 
It also supports experimental multi-agent/multi-goal mode where users can choose between different agent types and switch between them during conversations. @@ -14,6 +30,8 @@ The AI will respond with clarifications and ask for any missing information to t - Ollama models (local) - And many more! +## Demo Videos + It's really helpful to [watch the demo (5 minute YouTube video)](https://www.youtube.com/watch?v=GEXllEH2XiQ) to understand how interaction works. [![Watch the demo](./assets/agent-youtube-screenshot.jpeg)](https://www.youtube.com/watch?v=GEXllEH2XiQ) diff --git a/activities/tool_activities.py b/activities/tool_activities.py index 1380666..4107400 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -19,6 +19,7 @@ ValidationResult, ) from models.tool_definitions import MCPServerDefinition +from shared.llm_manager import LLMManager from shared.mcp_client_manager import MCPClientManager # Import MCP client libraries @@ -36,20 +37,24 @@ class ToolActivities: def __init__(self, mcp_client_manager: MCPClientManager = None): - """Initialize LLM client using LiteLLM and optional MCP client manager""" + """Initialize LLM client using LLMManager with fallback support and optional MCP client manager""" + # Use LLMManager for automatic fallback support + self.llm_manager = LLMManager() + + # Keep legacy attributes for backward compatibility self.llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4") self.llm_key = os.environ.get("LLM_KEY") self.llm_base_url = os.environ.get("LLM_BASE_URL") + self.mcp_client_manager = mcp_client_manager - print(f"Initializing ToolActivities with LLM model: {self.llm_model}") - if self.llm_base_url: - print(f"Using custom base URL: {self.llm_base_url}") + print(f"Initializing ToolActivities with LLMManager") if self.mcp_client_manager: print("MCP client manager enabled for connection pooling") + @activity.defn - async def agent_validatePrompt( - self, validation_input: ValidationInput + async def agent_validate_prompt( + self, validation_input: ValidationInput, fallback_mode: bool ) -> ValidationResult: """ Validates the prompt in the context of the conversation history and agent goal. @@ -101,15 +106,16 @@ async def agent_validatePrompt( prompt=validation_prompt, context_instructions=context_instructions ) - result = await self.agent_toolPlanner(prompt_input) + result = await self.agent_tool_planner(prompt_input, fallback_mode) return ValidationResult( validationResult=result.get("validationResult", False), validationFailedReason=result.get("validationFailedReason", {}), ) + @activity.defn - async def agent_toolPlanner(self, input: ToolPromptInput) -> dict: + async def agent_tool_planner(self, input: ToolPromptInput, fallback_mode: bool) -> dict: messages = [ { "role": "system", @@ -124,17 +130,7 @@ async def agent_toolPlanner(self, input: ToolPromptInput) -> dict: ] try: - completion_kwargs = { - "model": self.llm_model, - "messages": messages, - "api_key": self.llm_key, - } - - # Add base_url if configured - if self.llm_base_url: - completion_kwargs["base_url"] = self.llm_base_url - - response = completion(**completion_kwargs) + response = await self.llm_manager.call_llm(messages, fallback_mode) response_content = response.choices[0].message.content activity.logger.info(f"Raw LLM response: {repr(response_content)}") @@ -205,6 +201,39 @@ async def get_wf_env_vars(self, input: EnvLookupInput) -> EnvLookupOutput: return output + def warm_up_ollama(self) -> bool: + """ + Pre-load the Ollama model to avoid cold start latency. 
+ Returns True if successful, False otherwise. + """ + import time + + try: + start_time = time.time() + print("Sending warm-up request to Ollama...") + + # Make a simple completion request to load the model + response = completion( + model=self.llm_model, + messages=[{"role": "user", "content": "Hello"}], + api_key=self.llm_key, + base_url=self.llm_base_url, + ) + + end_time = time.time() + duration = end_time - start_time + + if response and response.choices: + print(f"✅ Model loaded successfully in {duration:.1f} seconds") + return True + else: + print("❌ Model loading failed: No response received") + return False + + except Exception as e: + print(f"❌ Model loading failed: {str(e)}") + return False + @activity.defn async def mcp_tool_activity( self, tool_name: str, tool_args: Dict[str, Any] diff --git a/dev-tools/README.md b/dev-tools/README.md new file mode 100644 index 0000000..f7c78f3 --- /dev/null +++ b/dev-tools/README.md @@ -0,0 +1,3 @@ +# Developer Tools + +This directory contains tools useful during development. \ No newline at end of file diff --git a/dev-tools/allow-anthropic.sh b/dev-tools/allow-anthropic.sh new file mode 100755 index 0000000..704986c --- /dev/null +++ b/dev-tools/allow-anthropic.sh @@ -0,0 +1,124 @@ +#!/usr/bin/env bash +set -euo pipefail + +HOST="${HOST:-api.anthropic.com}" +ANCHOR_NAME="anthropic" +ANCHOR_FILE="/etc/pf.anchors/${ANCHOR_NAME}" +PF_CONF="/etc/pf.conf" + +require_root() { + if [ "${EUID:-$(id -u)}" -ne 0 ]; then + echo "Please run with sudo." >&2 + exit 1 + fi +} + +backup_once() { + local file="$1" + if [ -f "$file" ] && [ ! -f "${file}.bak" ]; then + cp -p "$file" "${file}.bak" + fi +} + +ensure_anchors_dir() { + if [ ! -d "/etc/pf.anchors" ]; then + mkdir -p /etc/pf.anchors + chmod 755 /etc/pf.anchors + fi +} + +ensure_anchor_hook() { + if ! grep -qE '^\s*anchor\s+"'"${ANCHOR_NAME}"'"' "$PF_CONF"; then + echo "Wiring anchor into ${PF_CONF}..." + backup_once "$PF_CONF" + { + echo '' + echo '# --- Begin anthropic anchor hook ---' + echo 'anchor "'"${ANCHOR_NAME}"'"' + echo 'load anchor "'"${ANCHOR_NAME}"'" from "/etc/pf.anchors/'"${ANCHOR_NAME}"'"' + echo '# --- End anthropic anchor hook ---' + } >> "$PF_CONF" + fi +} + +default_iface() { + route -n get default 2>/dev/null | awk '/interface:/{print $2; exit}' +} + +resolve_ips() { + (dig +short A "$HOST"; dig +short AAAA "$HOST") 2>/dev/null \ + | awk 'NF' | sort -u +} + +write_anchor_allow() { + local iface="$1"; shift + local ips=("$@") + + local table_entries="" + if [ "${#ips[@]}" -gt 0 ]; then + for ip in "${ips[@]}"; do + if [ -n "$ip" ]; then + if [ -z "$table_entries" ]; then + table_entries="$ip" + else + table_entries="$table_entries, $ip" + fi + fi + done + fi + + backup_once "$ANCHOR_FILE" + { + echo "# ${ANCHOR_FILE}" + echo "# Auto-generated: $(date)" + echo "# Host: ${HOST}" + echo "table persist { ${table_entries} }" + echo "" + echo "# Allow outbound traffic to Anthropic" + echo "pass out quick on ${iface} to " + } > "$ANCHOR_FILE" +} + +enable_pf() { + pfctl -E >/dev/null 2>&1 || true +} + +reload_pf() { + if ! pfctl -nf "$PF_CONF" >/dev/null 2>&1; then + echo "pf.conf validation failed. Aborting." >&2 + exit 1 + fi + pfctl -f "$PF_CONF" >/dev/null +} + +main() { + require_root + ensure_anchors_dir + + local iface + iface="$(default_iface || true)" + if [ -z "${iface:-}" ]; then + echo "Could not determine default network interface." 
>&2 + exit 1 + fi + + ensure_anchor_hook + + ips=() + while IFS= read -r ip; do + ips+=("$ip") + done < <(resolve_ips) + + if [ "${#ips[@]}" -eq 0 ]; then + echo "Warning: No IPs resolved for ${HOST}. The table will be empty." >&2 + fi + + write_anchor_allow "$iface" "${ips[@]}" + enable_pf + reload_pf + + echo "✅ Anthropic API is now ALLOWED via pf on interface ${iface}." + echo "Anchor file: ${ANCHOR_FILE}" +} + +main "$@" diff --git a/dev-tools/block-anthropic.sh b/dev-tools/block-anthropic.sh new file mode 100755 index 0000000..16810d0 --- /dev/null +++ b/dev-tools/block-anthropic.sh @@ -0,0 +1,130 @@ +#!/usr/bin/env bash +set -euo pipefail + +HOST="${HOST:-api.anthropic.com}" +ANCHOR_NAME="anthropic" +ANCHOR_FILE="/etc/pf.anchors/${ANCHOR_NAME}" +PF_CONF="/etc/pf.conf" + +require_root() { + if [ "${EUID:-$(id -u)}" -ne 0 ]; then + echo "Please run with sudo." >&2 + exit 1 + fi +} + +backup_once() { + local file="$1" + # Only back up if the source exists and a .bak doesn't already exist + if [ -f "$file" ] && [ ! -f "${file}.bak" ]; then + cp -p "$file" "${file}.bak" + fi +} + +ensure_anchors_dir() { + if [ ! -d "/etc/pf.anchors" ]; then + mkdir -p /etc/pf.anchors + chmod 755 /etc/pf.anchors + fi +} + +ensure_anchor_hook() { + # Add an anchor include to pf.conf if it's not already there. + if ! grep -qE '^\s*anchor\s+"'"${ANCHOR_NAME}"'"' "$PF_CONF"; then + echo "Wiring anchor into ${PF_CONF}..." + backup_once "$PF_CONF" + { + echo '' + echo '# --- Begin anthropic anchor hook ---' + echo 'anchor "'"${ANCHOR_NAME}"'"' + echo 'load anchor "'"${ANCHOR_NAME}"'" from "/etc/pf.anchors/'"${ANCHOR_NAME}"'"' + echo '# --- End anthropic anchor hook ---' + } >> "$PF_CONF" + fi +} + +default_iface() { + route -n get default 2>/dev/null | awk '/interface:/{print $2; exit}' +} + +resolve_ips() { + # Resolve both A and AAAA; dedupe; ignore blanks + (dig +short A "$HOST"; dig +short AAAA "$HOST") 2>/dev/null \ + | awk 'NF' | sort -u +} + +write_anchor_block() { + local iface="$1"; shift + local ips=("$@") + + # Build table entries + local table_entries="" + if [ "${#ips[@]}" -gt 0 ]; then + for ip in "${ips[@]}"; do + if [ -n "$ip" ]; then + if [ -z "$table_entries" ]; then + table_entries="$ip" + else + table_entries="$table_entries, $ip" + fi + fi + done + fi + + backup_once "$ANCHOR_FILE" + { + echo "# ${ANCHOR_FILE}" + echo "# Auto-generated: $(date)" + echo "# Host: ${HOST}" + echo "table persist { ${table_entries} }" + echo "" + echo "# Block outbound traffic to Anthropic" + echo "block drop out quick on ${iface} to " + } > "$ANCHOR_FILE" +} + +enable_pf() { + pfctl -E >/dev/null 2>&1 || true # enable silently if disabled +} + +reload_pf() { + # Validate before applying + if ! pfctl -nf "$PF_CONF" >/dev/null 2>&1; then + echo "pf.conf validation failed. Aborting." >&2 + exit 1 + fi + pfctl -f "$PF_CONF" >/dev/null +} + +main() { + require_root + ensure_anchors_dir + + local iface + iface="$(default_iface || true)" + if [ -z "${iface:-}" ]; then + echo "Could not determine default network interface." >&2 + exit 1 + fi + + ensure_anchor_hook + + # Collect IPs without 'mapfile' (macOS bash 3.2 friendly) + ips=() + while IFS= read -r ip; do + ips+=("$ip") + done < <(resolve_ips) + + if [ "${#ips[@]}" -eq 0 ]; then + echo "Warning: No IPs resolved for ${HOST}. The table will be empty." >&2 + fi + + write_anchor_block "$iface" "${ips[@]}" + enable_pf + reload_pf + + echo "✅ Anthropic API is now BLOCKED via pf on interface ${iface}." 
+ echo "Anchor file: ${ANCHOR_FILE}" +} + +main "$@" diff --git a/docs/adding-goals-and-tools.md b/docs/adding-goals-and-tools.md index 69ff99b..4a6a4a8 100644 --- a/docs/adding-goals-and-tools.md +++ b/docs/adding-goals-and-tools.md @@ -1,5 +1,27 @@ # Customizing the Agent -The agent operates in single-agent mode by default, focusing on one specific goal. It also supports an experimental multi-agent mode where users can have multiple agents, each with their own goal, and supports switching back to choosing a new goal at the end of every successful goal (or even mid-goal). + +## Table of Contents +- [Adding a New Goal Category](#adding-a-new-goal-category) +- [Adding a Goal](#adding-a-goal) +- [Adding Native Tools](#adding-native-tools) + - [Note on Optional Tools](#note-on-optional-tools) + - [Add to Tool Registry](#add-to-tool-registry) + - [Create Each Native Tool Implementation](#create-each-native-tool-implementation) + - [Add to tools/__init__.py and the tool get_handler()](#add-to-tools__init__py-and-the-tool-get_handler) + - [Update workflow_helpers.py](#update-workflow_helperspy) +- [Adding MCP Tools](#adding-mcp-tools) + - [Configure MCP Server Definition](#configure-mcp-server-definition) + - [Using Predefined Configurations](#using-predefined-configurations) + - [Custom MCP Server Definition](#custom-mcp-server-definition) + - [MCP Tool Configuration](#mcp-tool-configuration) + - [How MCP Tools Work](#how-mcp-tools-work) +- [Tool Confirmation](#tool-confirmation) +- [Add a Goal & Tools Checklist](#add-a-goal--tools-checklist) + - [All Goals](#all-goals) + - [Native Tools](#native-tools) + - [MCP Tools](#mcp-tools) + +The agent operates in single-agent mode by default, focusing on one specific goal. It also supports an experimental multi-agent mode where users can have multiple agents, each with their own goal, and supports switching back to choosing a new goal at the end of every successful goal (or even mid-goal). A goal can use two types of tools: - **Native Tools**: Custom tools implemented directly in the codebase (in `/tools/`) @@ -155,19 +177,19 @@ I recommend exploring all three. For a demo, I would decide if you want the Argu ## Add a Goal & Tools Checklist -### For All Goals: +### All Goals - [ ] Create goal file in `/goals/` directory (e.g., `goals/my_category.py`) - [ ] Add goal to the category's goal list in the file - [ ] Import and extend the goal list in `goals/__init__.py` - [ ] If a new category, add Goal Category to [.env](./.env) and [.env.example](./.env.example) -### For Native Tools: +### Native Tools - [ ] Add native tools to [tool_registry.py](tools/tool_registry.py) - [ ] Implement tool functions in `/tools/` directory - [ ] Add tools to [tools/__init__.py](tools/__init__.py) in the `get_handler()` function - [ ] Add tool names to static tools list in [workflows/workflow_helpers.py](workflows/workflow_helpers.py) -### For MCP Tools: +### MCP Tools - [ ] Add `mcp_server_definition` to your goal configuration (use `shared/mcp_config.py` for common servers) - [ ] Ensure MCP server is available and properly configured - [ ] Set required environment variables (API keys, etc.) 
diff --git a/docs/architecture-decisions.md b/docs/architecture-decisions.md index 7c146d1..b1ae53e 100644 --- a/docs/architecture-decisions.md +++ b/docs/architecture-decisions.md @@ -1,4 +1,15 @@ # Architecture Decisions + +## Table of Contents +- [AI Models](#ai-models) +- [Temporal](#temporal) + - [Reliability and State Management](#reliability-and-state-management) + - [Handling Complex, Dynamic Workflows](#handling-complex-dynamic-workflows) + - [Scalability and Speed](#scalability-and-speed) + - [Observability and Debugging](#observability-and-debugging) + - [Simplified Error Handling](#simplified-error-handling) + - [Flexibility for Experimentation](#flexibility-for-experimentation) + This documents some of the "why" behind the [architecture](./architecture.md). ## AI Models diff --git a/docs/architecture.md b/docs/architecture.md index 745b060..2bc437a 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -1,4 +1,16 @@ # Elements + +## Table of Contents +- [Workflow](#workflow) + - [Workflow Responsibilities](#workflow-responsibilities) +- [Activities](#activities) +- [Tools](#tools) +- [Prompts](#prompts) +- [LLM](#llm) +- [Interaction](#interaction) +- [Architecture Model](#architecture-model) +- [Adding Features](#adding-features) + These are the main elements of this system. See [architecture decisions](./architecture-decisions.md) for information beind these choices. In this document we will explain each element and their interactions, and then connect them all at the end. Architecture Elements diff --git a/docs/contributing.md b/docs/contributing.md index c4d8c8f..fe2f80a 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,5 +1,12 @@ # Contributing to the Temporal AI Agent Project +## Table of Contents +- [Getting Started](#getting-started) + - [Code Style & Formatting](#code-style--formatting) + - [Linting & Type Checking](#linting--type-checking) +- [Testing](#testing) +- [Making Changes](#making-changes) + This document provides guidelines for contributing to `temporal-ai-agent`. All setup and installation instructions can be found in [setup.md](./setup.md). 
## Getting Started diff --git a/docs/setup.md b/docs/setup.md index b92a237..3270d05 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -1,4 +1,32 @@ # Setup Guide + +## Table of Contents + +- [Initial Configuration](#initial-configuration) + - [Quick Start with Makefile](#quick-start-with-makefile) + - [Manual Setup (Alternative to Makefile)](#manual-setup-alternative-to-makefile) + - [Agent Goal Configuration](#agent-goal-configuration) + - [LLM Configuration](#llm-configuration) + - [Fallback LLM Configuration](#fallback-llm-configuration) +- [Configuring Temporal Connection](#configuring-temporal-connection) + - [Use Temporal Cloud](#use-temporal-cloud) + - [Use a local Temporal Dev Server](#use-a-local-temporal-dev-server) +- [Running the Application](#running-the-application) + - [Docker](#docker) + - [Local Machine (no docker)](#local-machine-no-docker) +- [MCP Tools Configuration](#mcp-tools-configuration) + - [Adding MCP Tools to Goals](#adding-mcp-tools-to-goals) + - [MCP Environment Variables](#mcp-environment-variables) +- [Goal-Specific Tool Configuration](#goal-specific-tool-configuration) + - [Goal: Find an event in Australia / New Zealand, book flights to it and invoice the user for the cost](#goal-find-an-event-in-australia--new-zealand-book-flights-to-it-and-invoice-the-user-for-the-cost) + - [Goal: Find a Premier League match, book train tickets to it and invoice the user for the cost (Replay 2025 Keynote)](#goal-find-a-premier-league-match-book-train-tickets-to-it-and-invoice-the-user-for-the-cost-replay-2025-keynote) + - [Goals: FIN - Money Movement and Loan Application](#goals-fin---money-movement-and-loan-application) + - [Goals: HR/PTO](#goals-hrpto) + - [Goals: Ecommerce](#goals-ecommerce) + - [Goal: Food Ordering with MCP Integration (Stripe Payment Processing)](#goal-food-ordering-with-mcp-integration-stripe-payment-processing) +- [Customizing the Agent Further](#customizing-the-agent-further) +- [Setup Checklist](#setup-checklist) + ## Initial Configuration This application uses `.env` files for configuration. Copy the [.env.example](.env.example) file to `.env` and update the values: @@ -83,6 +111,7 @@ The agent uses LiteLLM to interact with various LLM providers. Configure the fol - `LLM_MODEL`: The model to use (e.g., "openai/gpt-4o", "anthropic/claude-3-sonnet", "google/gemini-pro", etc.) - `LLM_KEY`: Your API key for the selected provider +- `LLM_TIMEOUT_SECONDS`: (Optional) Request timeout in seconds. - `LLM_BASE_URL`: (Optional) Custom base URL for the LLM provider. Useful for: - Using Ollama with a custom endpoint - Using a proxy or custom API gateway @@ -111,6 +140,35 @@ LLM_BASE_URL=http://localhost:11434 For a complete list of supported models and providers, visit the [LiteLLM documentation](https://docs.litellm.ai/docs/providers). +#### Fallback LLM Configuration + +The system includes automatic fallback functionality to improve reliability when the primary LLM becomes unavailable. The LLM Manager provides transparent failover with automatic recovery detection. + +Configure fallback LLM settings in your `.env` file: + +```bash +# Fallback LLM Configuration +LLM_FALLBACK_MODEL=openai/gpt-4o-mini # Fallback model (often a cheaper/faster option) +LLM_FALLBACK_KEY=sk-proj-fallback-key... # API key for fallback LLM +LLM_FALLBACK_BASE_URL=... 
# Optional custom endpoint for fallback +LLM_FALLBACK_TIMEOUT_SECONDS=10 # Timeout for fallback LLM calls (default: 10) +``` + +##### Debug enabled for development + +Enable debugging to monitor LLM behavior and troubleshoot issues: + +```bash +# Debug Settings +LLM_DEBUG_OUTPUT=true # Enable debug file output (default: false) +LLM_DEBUG_OUTPUT_DIR=./debug_llm_calls # Debug output directory (default: ./debug_llm_calls) +``` + +#### Troubleshooting + +1. **Both LLMs failing**: Check API keys and network connectivity +2. **Timeout errors**: Increase timeout values or check network latency + ## Configuring Temporal Connection By default, this application will connect to a local Temporal server (`localhost:7233`) in the default namespace, using the `agent-task-queue` task queue. You can override these settings in your `.env` file. diff --git a/docs/testing.md b/docs/testing.md index 6a68054..e7a279f 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -1,5 +1,25 @@ # Testing the Temporal AI Agent +## Table of Contents +- [Quick Start](#quick-start) +- [Test Categories](#test-categories) + - [Unit Tests](#unit-tests) + - [Integration Tests](#integration-tests) +- [Running Specific Tests](#running-specific-tests) +- [Test Environment Options](#test-environment-options) + - [Local Environment (Default)](#local-environment-default) + - [Time-Skipping Environment (Recommended for CI)](#time-skipping-environment-recommended-for-ci) + - [External Temporal Server](#external-temporal-server) +- [Environment Variables](#environment-variables) +- [Test Coverage](#test-coverage) +- [Test Output](#test-output) +- [Troubleshooting](#troubleshooting) + - [Common Issues](#common-issues) + - [Debugging Tests](#debugging-tests) +- [Continuous Integration](#continuous-integration) +- [Additional Resources](#additional-resources) +- [Test Architecture](#test-architecture) + This guide provides instructions for running the comprehensive test suite for the Temporal AI Agent project. ## Quick Start diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 5bc88a3..e4b0220 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -1295,9 +1295,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001690", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001690.tgz", - "integrity": "sha512-5ExiE3qQN6oF8Clf8ifIDcMRCRE/dMGcETG/XGMD8/XiXm6HXQgQTh1yZYLXXpSOsEUlJm1Xr7kGULZTuGtP/w==", + "version": "1.0.30001743", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001743.tgz", + "integrity": "sha512-e6Ojr7RV14Un7dz6ASD0aZDmQPT/A+eZU+nuTNfjqmRrmkmQlnTNWH0SKmqagx9PeW87UVqapSurtAXifmtdmw==", "funding": [ { "type": "opencollective", diff --git a/goals/travel.py b/goals/travel.py index 11d8052..fa64b6e 100644 --- a/goals/travel.py +++ b/goals/travel.py @@ -80,7 +80,7 @@ "agent: Let's search for flights around these dates. Could you provide your departure city?", "user: San Francisco", "agent: Thanks, searching for flights from San Francisco to Sydney around 2023-02-25 to 2023-02-28.", - "user_confirmed_tool_run: " + "user_confirmed_tool_run: ", 'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}', "agent: Found some flights! The cheapest is CX101 for $850. 
Would you like to generate an invoice for this flight?", "user_confirmed_tool_run: ", diff --git a/scripts/run_worker.py b/scripts/run_worker.py index 5ba1c89..ce79c93 100644 --- a/scripts/run_worker.py +++ b/scripts/run_worker.py @@ -56,7 +56,15 @@ async def main(): print("===========================================================\n") print("Worker ready to process tasks!") - logging.basicConfig(level=logging.INFO) + + # Configure logging level from environment or default to INFO + log_level = os.environ.get("LOGLEVEL", "INFO").upper() + numeric_level = getattr(logging, log_level, logging.INFO) + logging.basicConfig( + level=numeric_level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + print(f"Logging configured at level: {log_level}") # Run the worker with proper cleanup try: @@ -68,8 +76,8 @@ async def main(): task_queue=TEMPORAL_TASK_QUEUE, workflows=[AgentGoalWorkflow], activities=[ - activities.agent_validatePrompt, - activities.agent_toolPlanner, + activities.agent_validate_prompt, + activities.agent_tool_planner, activities.get_wf_env_vars, activities.mcp_tool_activity, dynamic_tool_activity, diff --git a/shared/llm_manager.py b/shared/llm_manager.py new file mode 100644 index 0000000..d06875c --- /dev/null +++ b/shared/llm_manager.py @@ -0,0 +1,291 @@ +""" +LLM Manager with automatic fallback support using Temporal activity heartbeat for state persistence. + +Environment Variables: + LLM_MODEL: Primary LLM model (e.g., "openai/gpt-4") + LLM_KEY: API key for primary LLM + LLM_BASE_URL: Optional custom base URL for primary LLM + LLM_TIMEOUT_SECONDS: Timeout for primary LLM calls in seconds (default: 10) + LLM_FALLBACK_MODEL: Fallback LLM model + LLM_FALLBACK_KEY: API key for fallback LLM + LLM_FALLBACK_BASE_URL: Optional custom base URL for fallback LLM + LLM_FALLBACK_TIMEOUT_SECONDS: Timeout for fallback LLM calls in seconds (default: 10) + LLM_DEBUG_OUTPUT: Enable debug file output ("true"/"false", default: "false") + LLM_DEBUG_OUTPUT_DIR: Directory for debug files (default: "./debug_llm_calls") + +Usage: + manager = LLMManager() + response = await manager.call_llm([{"role": "user", "content": "Hello"}]) + +""" +import asyncio +import os +import json +from datetime import datetime +from typing import Any, Dict, List + +from dotenv import load_dotenv +from litellm import completion +from temporalio import activity + +load_dotenv(override=True) + + +class LLMManager: + """ + Manages LLM calls with intelligence to use the primary LLM or fallback. 
+ """ + + def __init__(self): + """Initialize LLM Manager with primary and fallback configurations.""" + # Primary LLM configuration + self.primary_model = os.environ.get("LLM_MODEL", "openai/gpt-4") + self.primary_key = os.environ.get("LLM_KEY") + self.primary_base_url = os.environ.get("LLM_BASE_URL") + self.primary_timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", "10")) + + # Fallback LLM configuration + self.fallback_model = os.environ.get("LLM_FALLBACK_MODEL") + self.fallback_key = os.environ.get("LLM_FALLBACK_KEY") + self.fallback_base_url = os.environ.get("LLM_FALLBACK_BASE_URL") + self.fallback_timeout_seconds = int( + os.environ.get("LLM_FALLBACK_TIMEOUT_SECONDS", "10") + ) + + # Debug file settings + self.debug_output_enabled = ( + os.environ.get("LLM_DEBUG_OUTPUT", "false").lower() == "true" + ) + self.debug_output_dir = os.environ.get( + "LLM_DEBUG_OUTPUT_DIR", "./debug_llm_calls" + ) + self._log_configuration() + + def _log_configuration(self): + """Log the LLM configuration for debugging.""" + print(f"[LLMManager._log_configuration] LLM Manager initialized:") + print(f" Primary model: {self.primary_model}") + print(f" Primary API key: {'***set***' if self.primary_key else 'not set'}") + print(f" Primary base URL: {self.primary_base_url or 'default'}") + + if self.fallback_model: + print(f" Fallback model: {self.fallback_model}") + print( + f" Fallback API key: {'***set***' if self.fallback_key else 'not set'}" + ) + print(f" Fallback base URL: {self.fallback_base_url or 'default'}") + else: + print(f" No fallback model configured") + + if self.debug_output_enabled: + print(f" Debug output enabled: {self.debug_output_dir}") + else: + print(f" Debug output disabled") + + print(f" Initial state: using_fallback=False, primary_failure_time=None") + + async def call_llm(self, messages: List[Dict[str, str]], fallback_mode: bool) -> Dict[str, Any]: + """ + Call LLM with automatic fallback support. + + Args: + messages: The messages to send to the LLM + + Returns: + The LLM response + + Raises: + Exception: If LLM call fails + """ + + await self._handle_debug(messages) + + # Determine which LLM + if fallback_mode: + response = await self._call_fallback_llm(messages) + else: + # Try primary LLM and throw exception on failure + response = await self._call_primary_llm_strict(messages) + + return response + + async def _call_primary_llm_strict( + self, messages: List[Dict[str, str]] + ) -> Dict[str, Any]: + """ + Call the primary LLM and throw exception on failure. 
+ + :param messages: LLM messages + :return: LLM response + :raises: Exception if primary LLM fails + """ + activity.logger.debug( + f"[LLMManager._call_primary_llm_strict] Attempting primary LLM call: {self.primary_model}" + ) + completion_kwargs = { + "model": self.primary_model, + "messages": messages, + "api_key": self.primary_key, + } + + if self.primary_base_url: + completion_kwargs["base_url"] = self.primary_base_url + + response = completion( + **completion_kwargs, timeout=self.primary_timeout_seconds + ) + activity.logger.info(f"Primary LLM call successful: {self.primary_model}") + + return response + + + async def _call_fallback_llm( + self, messages: List[Dict[str, str]] + ) -> Dict[str, Any]: + """Call the fallback LLM.""" + if not self.fallback_model: + raise Exception("No fallback model configured") + + activity.logger.info(f"Using fallback LLM: {self.fallback_model}") + + completion_kwargs = { + "model": self.fallback_model, + "messages": messages, + "api_key": self.fallback_key, + } + + if self.fallback_base_url: + completion_kwargs["base_url"] = self.fallback_base_url + + try: + response = completion( + **completion_kwargs, timeout=self.fallback_timeout_seconds + ) + activity.logger.info(f"Fallback LLM call successful: {self.fallback_model}") + return response + except Exception as fallback_error: + activity.logger.error(f"Fallback LLM also failed: {str(fallback_error)}") + raise Exception( + f"Fallback LLM failed. " + f"Primary: {self.primary_model}, Fallback: {self.fallback_model}" + ) + + async def _handle_debug(self, messages: List[Dict[str, str]]): + """ + Handle debug output if enabled. + """ + activity.logger.debug( + f"[LLMManager.call_llm] Starting LLM call with {len(messages)} messages" + ) + + # Save debug output if enabled + if self.debug_output_enabled: + await self._save_debug_output(messages) + + async def _save_debug_output(self, messages: List[Dict[str, str]]) -> None: + """ + Save LLM messages in a format that can be cut/pasted into an LLM interface. 
+ """ + activity.logger.debug( + f"[LLMManager._save_debug_output] Starting debug output save" + ) + try: + # Create debug directory if it doesn't exist + os.makedirs(self.debug_output_dir, exist_ok=True) + + # Clean up old files, keeping only the 20 most recent + self._cleanup_old_debug_files() + + # Generate timestamp-based filename + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[ + :-3 + ] # Include milliseconds + filename = f"llm_call_{timestamp}.txt" + filepath = os.path.join(self.debug_output_dir, filename) + activity.logger.debug( + f"[LLMManager._save_debug_output] Writing debug output to: {filepath}" + ) + + # Write to file + with open(filepath, "w") as f: + # Write header information + f.write(f"=== LLM Debug Output ===\n") + f.write(f"Timestamp: {datetime.now().isoformat()}\n") + f.write("=" * 50 + "\n\n") + + # Write each message in a readable format + activity.logger.debug( + f"[LLMManager._save_debug_output] Writing {len(messages)} messages to debug file" + ) + for i, message in enumerate(messages, 1): + role = message.get("role", "unknown") + content = message.get("content", "") + + f.write(f"As ({role.upper()}) :\n") + f.write(f"{content}\n") + f.write("\n" + "-" * 30 + "\n\n") + + # Add a section for easy copying + f.write("=== FOR MANUAL TESTING ===\n") + f.write("Copy the messages above and paste into your LLM interface.\n") + + activity.logger.debug(f"Saved LLM debug output to {filepath}") + except Exception as e: + activity.logger.warning(f"Failed to save LLM debug output: {str(e)}") + + def _cleanup_old_debug_files(self) -> None: + """Keep only the 20 most recent debug files, delete older ones.""" + activity.logger.debug( + f"[LLMManager._cleanup_old_debug_files] Starting cleanup of old debug files" + ) + try: + # Get all debug files in the directory + debug_files = [] + activity.logger.debug( + f"[LLMManager._cleanup_old_debug_files] Scanning directory: {self.debug_output_dir}" + ) + for filename in os.listdir(self.debug_output_dir): + if filename.startswith("llm_call_") and filename.endswith(".txt"): + filepath = os.path.join(self.debug_output_dir, filename) + if os.path.isfile(filepath): + # Get file modification time + mtime = os.path.getmtime(filepath) + debug_files.append((filepath, mtime)) + + activity.logger.debug( + f"[LLMManager._cleanup_old_debug_files] Found {len(debug_files)} debug files" + ) + + # Sort by modification time (newest first) + debug_files.sort(key=lambda x: x[1], reverse=True) + + # Keep only the 20 most recent files, delete the rest + if len(debug_files) > 20: + files_to_delete = debug_files[20:] + activity.logger.debug( + f"[LLMManager._cleanup_old_debug_files] Need to delete {len(files_to_delete)} old files" + ) + for filepath, _ in files_to_delete: + try: + os.remove(filepath) + activity.logger.debug( + f"[LLMManager._cleanup_old_debug_files] Deleted old debug file: {filepath}" + ) + activity.logger.debug(f"Deleted old debug file: {filepath}") + except OSError as e: + activity.logger.debug( + f"[LLMManager._cleanup_old_debug_files] Failed to delete {filepath}: {str(e)}" + ) + activity.logger.warning( + f"Failed to delete old debug file {filepath}: {str(e)}" + ) + else: + activity.logger.debug( + f"[LLMManager._cleanup_old_debug_files] No cleanup needed, {len(debug_files)} files <= 20 limit" + ) + + except Exception as e: + activity.logger.debug( + f"[LLMManager._cleanup_old_debug_files] Cleanup failed: {type(e).__name__}: {str(e)}" + ) + activity.logger.warning(f"Failed to cleanup old debug files: {str(e)}") diff --git 
a/tests/README.md b/tests/README.md index 95599ce..9bdb470 100644 --- a/tests/README.md +++ b/tests/README.md @@ -2,6 +2,44 @@ This directory contains comprehensive tests for the Temporal AI Agent project. The tests cover workflows, activities, and integration scenarios using Temporal's testing framework. +## Table of Contents + +- [Test Structure](#test-structure) +- [Test Types](#test-types) + - [1. Workflow Tests (`test_agent_goal_workflow.py`)](#1-workflow-tests-test_agent_goal_workflowpy) + - [2. Activity Tests (`test_tool_activities.py`)](#2-activity-tests-test_tool_activitiespy) + - [3. Configuration Tests (`conftest.py`)](#3-configuration-tests-conftestpy) +- [Running Tests](#running-tests) + - [Prerequisites](#prerequisites) + - [Basic Test Execution](#basic-test-execution) + - [Test Environment Options](#test-environment-options) + - [Filtering Tests](#filtering-tests) +- [Test Configuration](#test-configuration) + - [Test Discovery](#test-discovery) + - [Environment Variables](#environment-variables) + - [Mocking Strategy](#mocking-strategy) +- [Writing New Tests](#writing-new-tests) + - [Test Naming Convention](#test-naming-convention) + - [Using Fixtures](#using-fixtures) + - [Mocking External Dependencies](#mocking-external-dependencies) + - [Testing Workflow Signals and Queries](#testing-workflow-signals-and-queries) +- [Test Data and Fixtures](#test-data-and-fixtures) + - [Sample Agent Goal](#sample-agent-goal) + - [Sample Conversation History](#sample-conversation-history) + - [Sample Combined Input](#sample-combined-input) +- [Debugging Tests](#debugging-tests) + - [Verbose Logging](#verbose-logging) + - [Temporal Web UI](#temporal-web-ui) + - [Test Isolation](#test-isolation) +- [Continuous Integration](#continuous-integration) + - [GitHub Actions Example](#github-actions-example) + - [Test Coverage](#test-coverage) +- [Best Practices](#best-practices) +- [Troubleshooting](#troubleshooting) + - [Common Issues](#common-issues) + - [Getting Help](#getting-help) +- [Legacy Tests](#legacy-tests) + ## Test Structure ``` @@ -32,7 +70,7 @@ Tests the main `AgentGoalWorkflow` class covering: Tests the `ToolActivities` class and `dynamic_tool_activity` function: - **LLM Integration**: Testing agent_toolPlanner with mocked LLM responses -- **Validation Logic**: Testing agent_validatePrompt with various scenarios +- **Validation Logic**: Testing agent_validate_prompt with various scenarios - **Environment Configuration**: Testing get_wf_env_vars with different env setups - **JSON Processing**: Testing response parsing and sanitization - **Dynamic Tool Execution**: Testing the dynamic activity dispatcher diff --git a/tests/test_agent_goal_workflow.py b/tests/test_agent_goal_workflow.py index 6826eb0..3582990 100644 --- a/tests/test_agent_goal_workflow.py +++ b/tests/test_agent_goal_workflow.py @@ -74,14 +74,15 @@ async def test_user_prompt_signal( async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput: return EnvLookupOutput(show_confirm=True, multi_goal_mode=True) - @activity.defn(name="agent_validatePrompt") - async def mock_agent_validatePrompt( + @activity.defn(name="agent_validate_prompt") + async def mock_agent_validate_prompt( validation_input: ValidationInput, + fallback_mode: bool, ) -> ValidationResult: return ValidationResult(validationResult=True, validationFailedReason={}) - @activity.defn(name="agent_toolPlanner") - async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict: + @activity.defn(name="agent_tool_planner") + async def 
mock_agent_tool_planner(input: ToolPromptInput, fallback_mode: bool) -> dict: return {"next": "done", "response": "Test response from LLM"} async with Worker( @@ -90,8 +91,8 @@ async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict: workflows=[AgentGoalWorkflow], activities=[ mock_get_wf_env_vars, - mock_agent_validatePrompt, - mock_agent_toolPlanner, + mock_agent_validate_prompt, + mock_agent_tool_planner, ], ): handle = await client.start_workflow( @@ -139,14 +140,15 @@ async def test_confirm_signal( async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput: return EnvLookupOutput(show_confirm=True, multi_goal_mode=True) - @activity.defn(name="agent_validatePrompt") - async def mock_agent_validatePrompt( + @activity.defn(name="agent_validate_prompt") + async def mock_agent_validate_prompt( validation_input: ValidationInput, + fallback_mode: bool, ) -> ValidationResult: return ValidationResult(validationResult=True, validationFailedReason={}) - @activity.defn(name="agent_toolPlanner") - async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict: + @activity.defn(name="agent_tool_planner") + async def mock_agent_tool_planner(input: ToolPromptInput, fallback_mode: bool) -> dict: return { "next": "confirm", "tool": "TestTool", @@ -164,8 +166,8 @@ async def mock_test_tool(args: dict) -> dict: workflows=[AgentGoalWorkflow], activities=[ mock_get_wf_env_vars, - mock_agent_validatePrompt, - mock_agent_toolPlanner, + mock_agent_validate_prompt, + mock_agent_tool_planner, mock_test_tool, ], ): @@ -207,9 +209,10 @@ async def test_validation_failure( async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput: return EnvLookupOutput(show_confirm=True, multi_goal_mode=True) - @activity.defn(name="agent_validatePrompt") - async def mock_agent_validatePrompt( + @activity.defn(name="agent_validate_prompt") + async def mock_agent_validate_prompt( validation_input: ValidationInput, + fallback_mode: bool, ) -> ValidationResult: return ValidationResult( validationResult=False, @@ -223,7 +226,7 @@ async def mock_agent_validatePrompt( client, task_queue=task_queue_name, workflows=[AgentGoalWorkflow], - activities=[mock_get_wf_env_vars, mock_agent_validatePrompt], + activities=[mock_get_wf_env_vars, mock_agent_validate_prompt], ): handle = await client.start_workflow( AgentGoalWorkflow.run, @@ -480,14 +483,15 @@ async def test_multiple_user_prompts( async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput: return EnvLookupOutput(show_confirm=True, multi_goal_mode=True) - @activity.defn(name="agent_validatePrompt") - async def mock_agent_validatePrompt( + @activity.defn(name="agent_validate_prompt") + async def mock_agent_validate_prompt( validation_input: ValidationInput, + fallback_mode: bool, ) -> ValidationResult: return ValidationResult(validationResult=True, validationFailedReason={}) - @activity.defn(name="agent_toolPlanner") - async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict: + @activity.defn(name="agent_tool_planner") + async def mock_agent_tool_planner(input: ToolPromptInput, fallback_mode: bool) -> dict: # Keep workflow running for multiple prompts return {"next": "question", "response": f"Processed: {input.prompt}"} @@ -497,8 +501,8 @@ async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict: workflows=[AgentGoalWorkflow], activities=[ mock_get_wf_env_vars, - mock_agent_validatePrompt, - mock_agent_toolPlanner, + mock_agent_validate_prompt, + mock_agent_tool_planner, ], ): handle = await client.start_workflow( @@ 
-541,3 +545,13 @@ async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict: # Verify at least the first message was processed message_texts = [str(msg["response"]) for msg in user_messages] assert any("First message" in text for text in message_texts) + + + + + + + + + + diff --git a/tests/test_agent_goal_workflow_execute_prompt.py b/tests/test_agent_goal_workflow_execute_prompt.py new file mode 100644 index 0000000..f242085 --- /dev/null +++ b/tests/test_agent_goal_workflow_execute_prompt.py @@ -0,0 +1,175 @@ +import pytest + +from models.data_types import ToolPromptInput +from models.tool_definitions import AgentGoal, ToolDefinition, ToolArgument +import workflows.agent_goal_workflow as agw_module + + +@pytest.mark.asyncio +async def test__execute_prompt_success(monkeypatch): + """ + Test that _execute_prompt calls execute_activity_method as expected + """ + # Arrange: create workflow instance and set a concrete goal + wf = agw_module.AgentGoalWorkflow() + wf.fallback_mode = False + wf.goal = AgentGoal( + id="unit_goal_id", + category_tag="unit_cat", + agent_name="UnitAgent", + agent_friendly_description="Unit test agent", + description="Unit test goal description", + tools=[ + ToolDefinition( + name="UnitTool", + description="Unit tool", + arguments=[ToolArgument(name="param", type="string", description="p")], + ) + ], + ) + + # Capture container + captured = {"activity_called": False} + + # Create a minimal workflow mock exposing only execute_activity_method + class WorkflowMock: + async def execute_activity_method(self, activity, *, args, schedule_to_close_timeout, start_to_close_timeout, retry_policy, summary): + captured["activity_called"] = True + + # Validate args structure and values + assert isinstance(args, list) + assert len(args) == 2 + prompt_input, fallback_mode = args + assert isinstance(prompt_input, ToolPromptInput) + assert prompt_input.prompt == "test prompt" + assert prompt_input.context_instructions == "test context" + assert fallback_mode is False + + # Validate timeouts and retry policy were forwarded + assert schedule_to_close_timeout == agw_module.LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT + assert start_to_close_timeout == agw_module.LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT + # Retry policy values as defined in workflow implementation + assert retry_policy.initial_interval.total_seconds() == 5 + assert retry_policy.backoff_coefficient == 1 + assert retry_policy.maximum_attempts == 2 + + # Summary should be empty when not in fallback + assert summary == "" + + # Return a successful tool_data response + return { + "next": "confirm", + "tool": "UnitTool", + "args": {"param": "value"}, + "response": "Tool response" + } + + # Monkeypatch the module-level `workflow` object to our minimal mock + monkeypatch.setattr(agw_module, "workflow", WorkflowMock(), raising=True) + + # Create prompt input + prompt_input = ToolPromptInput( + prompt="test prompt", + context_instructions="test context" + ) + + # Act + result = await wf._execute_prompt(prompt_input) + + # Assert + assert captured["activity_called"] is True + assert isinstance(result, dict) + assert result["next"] == "confirm" + assert result["tool"] == "UnitTool" + assert result["args"] == {"param": "value"} + assert result["response"] == "Tool response" + + +@pytest.mark.asyncio +async def test__execute_prompt_activityerror_triggers_fallback(monkeypatch): + """ + Test that _execute_prompt handles the ActivityError by calling the fallback. 
+ """ + # Arrange + wf = agw_module.AgentGoalWorkflow() + wf.fallback_mode = False + wf.goal = AgentGoal( + id="unit_goal_id_2", + category_tag="unit_cat", + agent_name="UnitAgent", + agent_friendly_description="Unit test agent", + description="Unit test goal description", + tools=[ + ToolDefinition( + name="UnitTool", + description="Unit tool", + arguments=[ToolArgument(name="param", type="string", description="p")], + ) + ], + ) + + calls = {"count": 0} + + from temporalio.exceptions import ActivityError + from temporalio.api.enums.v1 import RetryState + + class WorkflowMock: + class _Logger: + def info(self, *args, **kwargs): + return None + + logger = _Logger() + async def execute_activity_method(self, activity, *, args, schedule_to_close_timeout, start_to_close_timeout, retry_policy, summary): + calls["count"] += 1 + prompt_input, fallback_mode = args + if calls["count"] == 1: + # First attempt should be non-fallback and raise ActivityError + assert fallback_mode is False + assert summary == "" + raise ActivityError( + message="primary failure", + scheduled_event_id=1, + started_event_id=2, + identity="unit-test", + activity_type="agent_tool_planner", + activity_id="activity-1", + retry_state=RetryState.RETRY_STATE_MAXIMUM_ATTEMPTS_REACHED, + ) + + # Second attempt should be fallback + assert fallback_mode is True + assert summary == "fallback" + assert prompt_input.prompt == "test prompt" + assert prompt_input.context_instructions == "test context" + + # Validate timeouts and retry policy forwarded + assert schedule_to_close_timeout == agw_module.LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT + assert start_to_close_timeout == agw_module.LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT + assert retry_policy.initial_interval.total_seconds() == 5 + assert retry_policy.backoff_coefficient == 1 + assert retry_policy.maximum_attempts == 2 + + return { + "next": "confirm", + "tool": "UnitTool", + "args": {"param": "value"}, + "response": "Fallback tool response" + } + + monkeypatch.setattr(agw_module, "workflow", WorkflowMock(), raising=True) + + # Create prompt input + prompt_input = ToolPromptInput( + prompt="test prompt", + context_instructions="test context" + ) + + # Act + result = await wf._execute_prompt(prompt_input) + + # Assert + assert isinstance(result, dict) + assert result["next"] == "confirm" + assert result["response"] == "Fallback tool response" + assert wf.fallback_mode is True + assert calls["count"] == 2 \ No newline at end of file diff --git a/tests/test_agent_goal_workflow_validate_prompt.py b/tests/test_agent_goal_workflow_validate_prompt.py new file mode 100644 index 0000000..54f085e --- /dev/null +++ b/tests/test_agent_goal_workflow_validate_prompt.py @@ -0,0 +1,167 @@ +import pytest + +from models.data_types import ValidationInput, ValidationResult +from models.tool_definitions import AgentGoal, ToolDefinition, ToolArgument +import workflows.agent_goal_workflow as agw_module + + +@pytest.mark.asyncio +async def test__validate_prompt_success(monkeypatch): + """ + Test that _validate_prompt calls execute_activity_method as expected + """ + # Arrange: create workflow instance and set a concrete goal + wf = agw_module.AgentGoalWorkflow() + wf.fallback_mode = False + wf.goal = AgentGoal( + id="unit_goal_id", + category_tag="unit_cat", + agent_name="UnitAgent", + agent_friendly_description="Unit test agent", + description="Unit test goal description", + tools=[ + ToolDefinition( + name="UnitTool", + description="Unit tool", + arguments=[ToolArgument(name="param", type="string", 
description="p")], + ) + ], + ) + + # Capture container + captured = {"add_message_called": False, "activity_called": False} + + # Mock add_message on the instance + def fake_add_message(actor, response): + captured["add_message_called"] = True + assert actor == "user" + assert response == "unit_prompt" + + monkeypatch.setattr(wf, "add_message", fake_add_message, raising=True) + + # Create a minimal workflow mock exposing only execute_activity_method + class WorkflowMock: + async def execute_activity_method(self, activity, *, args, schedule_to_close_timeout, start_to_close_timeout, retry_policy, summary): + captured["activity_called"] = True + + # Validate args structure and values + assert isinstance(args, list) + assert len(args) == 2 + validation_input, fallback_mode = args + assert isinstance(validation_input, ValidationInput) + assert validation_input.prompt == "unit_prompt" + assert validation_input.agent_goal.id == "unit_goal_id" + assert isinstance(validation_input.conversation_history, dict) + assert "messages" in validation_input.conversation_history + assert fallback_mode is False + + # Validate timeouts and retry policy were forwarded + assert schedule_to_close_timeout == agw_module.LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT + assert start_to_close_timeout == agw_module.LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT + # Retry policy values as defined in workflow implementation + assert retry_policy.initial_interval.total_seconds() == 5 + assert retry_policy.backoff_coefficient == 1 + assert retry_policy.maximum_attempts == 2 + + # Summary should be empty when not in fallback + assert summary == "" + + # Return a successful ValidationResult + return ValidationResult(validationResult=True, validationFailedReason={}) + + # Monkeypatch the module-level `workflow` object to our minimal mock + monkeypatch.setattr(agw_module, "workflow", WorkflowMock(), raising=True) + + # Act + result = await wf._validate_prompt("unit_prompt") + + # Assert + assert captured["add_message_called"] is True + assert captured["activity_called"] is True + assert isinstance(result, ValidationResult) + assert result.validationResult is True + + +@pytest.mark.asyncio +async def test__validate_prompt_activityerror_triggers_fallback(monkeypatch): + """ + Test that _validate_prompt handles the ActivityError by calling the fallback. 
+ """ + # Arrange + wf = agw_module.AgentGoalWorkflow() + wf.fallback_mode = False + wf.goal = AgentGoal( + id="unit_goal_id_2", + category_tag="unit_cat", + agent_name="UnitAgent", + agent_friendly_description="Unit test agent", + description="Unit test goal description", + tools=[ + ToolDefinition( + name="UnitTool", + description="Unit tool", + arguments=[ToolArgument(name="param", type="string", description="p")], + ) + ], + ) + + calls = {"count": 0} + + # Ensure add_message still works but we don't assert twice here + def fake_add_message(actor, response): + assert actor == "user" + assert response == "needs_fallback" + + monkeypatch.setattr(wf, "add_message", fake_add_message, raising=True) + + from temporalio.exceptions import ActivityError + from temporalio.api.enums.v1 import RetryState + + class WorkflowMock: + class _Logger: + def info(self, *args, **kwargs): + return None + + logger = _Logger() + async def execute_activity_method(self, activity, *, args, schedule_to_close_timeout, start_to_close_timeout, retry_policy, summary): + calls["count"] += 1 + validation_input, fallback_mode = args + if calls["count"] == 1: + # First attempt should be non-fallback and raise ActivityError + assert fallback_mode is False + assert summary == "" + raise ActivityError( + message="primary failure", + scheduled_event_id=1, + started_event_id=2, + identity="unit-test", + activity_type="agent_validate_prompt", + activity_id="activity-1", + retry_state=RetryState.RETRY_STATE_MAXIMUM_ATTEMPTS_REACHED, + ) + + # Second attempt should be fallback + assert fallback_mode is True + assert summary == "fallback" + assert validation_input.prompt == "needs_fallback" + assert validation_input.agent_goal.id == "unit_goal_id_2" + + # Validate timeouts and retry policy forwarded + assert schedule_to_close_timeout == agw_module.LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT + assert start_to_close_timeout == agw_module.LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT + assert retry_policy.initial_interval.total_seconds() == 5 + assert retry_policy.backoff_coefficient == 1 + assert retry_policy.maximum_attempts == 2 + + return ValidationResult(validationResult=True, validationFailedReason={}) + + monkeypatch.setattr(agw_module, "workflow", WorkflowMock(), raising=True) + + # Act + result = await wf._validate_prompt("needs_fallback") + + # Assert + assert isinstance(result, ValidationResult) + assert result.validationResult is True + assert wf.fallback_mode is True + assert calls["count"] == 2 diff --git a/tests/test_llm_manager.py b/tests/test_llm_manager.py new file mode 100644 index 0000000..8eeedbc --- /dev/null +++ b/tests/test_llm_manager.py @@ -0,0 +1,381 @@ +"""Tests for LLM Manager with fallback support.""" +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from shared.llm_manager import LLMManager + + +class TestLLMManagerConfiguration: + """Test cases for LLMManager configuration.""" + + def setup_method(self): + """Set up test environment for each test.""" + # Clear any existing environment variables + env_vars_to_clear = [ + "LLM_MODEL", + "LLM_KEY", + "LLM_BASE_URL", + "LLM_TIMEOUT_SECONDS", + "LLM_FALLBACK_MODEL", + "LLM_FALLBACK_KEY", + "LLM_FALLBACK_BASE_URL", + "LLM_FALLBACK_TIMEOUT_SECONDS", + "LLM_RECOVERY_CHECK_INTERVAL_SECONDS", + ] + for var in env_vars_to_clear: + os.environ.pop(var, None) + + def test_initialization_defaults(self): + """Test default initialization values.""" + with patch.dict(os.environ, {}, clear=True): + manager = LLMManager() + + assert 
manager.primary_model == "openai/gpt-4" + assert manager.primary_key is None + assert manager.primary_base_url is None + assert manager.primary_timeout_seconds == 10 + assert manager.fallback_model is None + assert manager.fallback_key is None + assert manager.fallback_base_url is None + assert manager.fallback_timeout_seconds == 10 + + def test_initialization_with_environment_variables(self): + """Test initialization with custom environment variables.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "anthropic/claude-3-5-sonnet-20241022", + "LLM_KEY": "primary-key-123", + "LLM_BASE_URL": "https://primary.api.com", + "LLM_TIMEOUT_SECONDS": "30", + "LLM_FALLBACK_MODEL": "openai/gpt-4o", + "LLM_FALLBACK_KEY": "fallback-key-456", + "LLM_FALLBACK_BASE_URL": "https://fallback.api.com", + "LLM_FALLBACK_TIMEOUT_SECONDS": "20", + }, + ): + manager = LLMManager() + + assert manager.primary_model == "anthropic/claude-3-5-sonnet-20241022" + assert manager.primary_key == "primary-key-123" + assert manager.primary_base_url == "https://primary.api.com" + assert manager.primary_timeout_seconds == 30 + assert manager.fallback_model == "openai/gpt-4o" + assert manager.fallback_key == "fallback-key-456" + assert manager.fallback_base_url == "https://fallback.api.com" + assert manager.fallback_timeout_seconds == 20 + + +class TestLLMManagerPrimaryLLM: + """Test cases for primary LLM functionality.""" + + def setup_method(self): + """Set up test environment for each test.""" + env_vars_to_clear = [ + "LLM_MODEL", + "LLM_KEY", + "LLM_BASE_URL", + "LLM_TIMEOUT_SECONDS", + "LLM_FALLBACK_MODEL", + "LLM_FALLBACK_KEY", + "LLM_FALLBACK_BASE_URL", + "LLM_FALLBACK_TIMEOUT_SECONDS", + ] + for var in env_vars_to_clear: + os.environ.pop(var, None) + + @pytest.mark.asyncio + async def test_primary_llm_success(self): + """Test successful primary LLM call.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + }, + ): + manager = LLMManager() + + # Mock the completion function + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Test response" + + with patch("shared.llm_manager.completion", return_value=mock_response): + messages = [{"role": "user", "content": "Hello"}] + response = await manager.call_llm(messages, fallback_mode=False) + + assert response == mock_response + + @pytest.mark.asyncio + async def test_primary_llm_with_custom_base_url(self): + """Test primary LLM call with custom base URL.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + "LLM_BASE_URL": "https://custom.api.com", + }, + ): + manager = LLMManager() + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + + with patch("shared.llm_manager.completion", return_value=mock_response) as mock_completion: + messages = [{"role": "user", "content": "Hello"}] + await manager.call_llm(messages, fallback_mode=False) + + # Verify completion was called with base_url + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["base_url"] == "https://custom.api.com" + assert call_kwargs["model"] == "openai/gpt-4" + assert call_kwargs["api_key"] == "test-key-123" + + @pytest.mark.asyncio + async def test_primary_llm_custom_timeout(self): + """Test primary LLM call with custom timeout.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + "LLM_TIMEOUT_SECONDS": "30", + }, + ): 
+ manager = LLMManager() + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + + with patch("shared.llm_manager.completion", return_value=mock_response) as mock_completion: + messages = [{"role": "user", "content": "Hello"}] + await manager.call_llm(messages, fallback_mode=False) + + # Verify timeout was passed correctly + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["timeout"] == 30 + + @pytest.mark.asyncio + async def test_primary_llm_failure_raises_exception(self): + """Test that primary LLM failure raises exception.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + }, + ): + manager = LLMManager() + + # Mock completion to fail + with patch("shared.llm_manager.completion", side_effect=Exception("API Error")): + messages = [{"role": "user", "content": "Hello"}] + + with pytest.raises(Exception, match="API Error"): + await manager.call_llm(messages, fallback_mode=False) + + +class TestLLMManagerFallbackLLM: + """Test cases for fallback LLM functionality.""" + + def setup_method(self): + """Set up test environment for each test.""" + env_vars_to_clear = [ + "LLM_MODEL", + "LLM_KEY", + "LLM_FALLBACK_MODEL", + "LLM_FALLBACK_KEY", + "LLM_FALLBACK_BASE_URL", + "LLM_FALLBACK_TIMEOUT_SECONDS", + ] + for var in env_vars_to_clear: + os.environ.pop(var, None) + + @pytest.mark.asyncio + async def test_fallback_llm_success(self): + """Test successful fallback LLM call when in fallback mode.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + "LLM_FALLBACK_MODEL": "anthropic/claude-3-5-sonnet-20241022", + "LLM_FALLBACK_KEY": "fallback-key-456", + }, + ): + manager = LLMManager() + + # Mock the completion function + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Fallback response" + + with patch("shared.llm_manager.completion", return_value=mock_response): + messages = [{"role": "user", "content": "Hello"}] + response = await manager.call_llm(messages, fallback_mode=True) + + assert response == mock_response + + @pytest.mark.asyncio + async def test_no_fallback_configured(self): + """Test error when fallback is needed but not configured.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + }, + ): + manager = LLMManager() + + messages = [{"role": "user", "content": "Hello"}] + + with pytest.raises(Exception, match="No fallback model configured"): + await manager.call_llm(messages, fallback_mode=True) + + @pytest.mark.asyncio + async def test_fallback_llm_with_custom_base_url(self): + """Test fallback LLM call with custom base URL.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + "LLM_FALLBACK_MODEL": "anthropic/claude-3-5-sonnet-20241022", + "LLM_FALLBACK_KEY": "fallback-key-456", + "LLM_FALLBACK_BASE_URL": "https://custom.fallback.com", + }, + ): + manager = LLMManager() + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Fallback response" + + with patch("shared.llm_manager.completion", return_value=mock_response) as mock_completion: + messages = [{"role": "user", "content": "Hello"}] + await manager.call_llm(messages, fallback_mode=True) + + # Verify completion was called with fallback base_url + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["base_url"] == 
"https://custom.fallback.com" + assert call_kwargs["model"] == "anthropic/claude-3-5-sonnet-20241022" + assert call_kwargs["api_key"] == "fallback-key-456" + + @pytest.mark.asyncio + async def test_fallback_llm_custom_timeout(self): + """Test fallback LLM call with custom timeout.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + "LLM_FALLBACK_MODEL": "anthropic/claude-3-5-sonnet-20241022", + "LLM_FALLBACK_KEY": "fallback-key-456", + "LLM_FALLBACK_TIMEOUT_SECONDS": "20", + }, + ): + manager = LLMManager() + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + + with patch("shared.llm_manager.completion", return_value=mock_response) as mock_completion: + messages = [{"role": "user", "content": "Hello"}] + await manager.call_llm(messages, fallback_mode=True) + + # Verify timeout was passed correctly + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["timeout"] == 20 + + @pytest.mark.asyncio + async def test_both_llms_fail(self): + """Test error when both primary and fallback LLMs fail.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + "LLM_FALLBACK_MODEL": "anthropic/claude-3-5-sonnet-20241022", + "LLM_FALLBACK_KEY": "fallback-key-456", + }, + ): + manager = LLMManager() + + # Mock completion to fail + with patch("shared.llm_manager.completion", side_effect=Exception("Connection failed")): + messages = [{"role": "user", "content": "Hello"}] + + with pytest.raises(Exception, match="Fallback LLM failed"): + await manager.call_llm(messages, fallback_mode=True) + + +class TestLLMManagerMessageHandling: + """Test cases for message handling.""" + + def setup_method(self): + """Set up test environment for each test.""" + env_vars_to_clear = ["LLM_MODEL", "LLM_KEY"] + for var in env_vars_to_clear: + os.environ.pop(var, None) + + @pytest.mark.asyncio + async def test_message_format_single_message(self): + """Test that single message is passed correctly to the LLM.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + }, + ): + manager = LLMManager() + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + + with patch("shared.llm_manager.completion", return_value=mock_response) as mock_completion: + messages = [{"role": "user", "content": "Hello"}] + await manager.call_llm(messages, fallback_mode=False) + + # Verify messages were passed correctly + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["messages"] == messages + + @pytest.mark.asyncio + async def test_message_format_multiple_messages(self): + """Test that multiple messages are passed correctly to the LLM.""" + with patch.dict( + os.environ, + { + "LLM_MODEL": "openai/gpt-4", + "LLM_KEY": "test-key-123", + }, + ): + manager = LLMManager() + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + + with patch("shared.llm_manager.completion", return_value=mock_response) as mock_completion: + messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + await manager.call_llm(messages, fallback_mode=False) + + # Verify messages were passed correctly + call_kwargs = mock_completion.call_args[1] + 
assert call_kwargs["messages"] == messages + assert len(call_kwargs["messages"]) == 4 \ No newline at end of file diff --git a/tests/test_mcp_integration.py b/tests/test_mcp_integration.py index b01a4eb..71084f1 100644 --- a/tests/test_mcp_integration.py +++ b/tests/test_mcp_integration.py @@ -240,12 +240,12 @@ async def test_mcp_tool_execution_flow(client: Client): async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput: return EnvLookupOutput(show_confirm=True, multi_goal_mode=True) - @activity.defn(name="agent_validatePrompt") - async def mock_validate(prompt: ValidationInput) -> ValidationResult: + @activity.defn(name="agent_validate_prompt") + async def mock_validate(prompt: ValidationInput, fallback_mode: bool) -> ValidationResult: return ValidationResult(validationResult=True, validationFailedReason={}) - @activity.defn(name="agent_toolPlanner") - async def mock_planner(input: ToolPromptInput) -> dict: + @activity.defn(name="agent_tool_planner") + async def mock_planner(input: ToolPromptInput, fallback_mode: bool) -> dict: if "planner_called" not in captured: captured["planner_called"] = True return { @@ -341,12 +341,12 @@ async def test_create_invoice_defaults_days_until_due(client: Client): async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput: return EnvLookupOutput(show_confirm=True, multi_goal_mode=True) - @activity.defn(name="agent_validatePrompt") - async def mock_validate(prompt: ValidationInput) -> ValidationResult: + @activity.defn(name="agent_validate_prompt") + async def mock_validate(prompt: ValidationInput, fallback_mode: bool) -> ValidationResult: return ValidationResult(validationResult=True, validationFailedReason={}) - @activity.defn(name="agent_toolPlanner") - async def mock_planner(input: ToolPromptInput) -> dict: + @activity.defn(name="agent_tool_planner") + async def mock_planner(input: ToolPromptInput, fallback_mode: bool) -> dict: if "planner_called" not in captured: captured["planner_called"] = True return { @@ -442,12 +442,12 @@ async def test_mcp_tool_failure_recorded(client: Client): async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput: return EnvLookupOutput(show_confirm=True, multi_goal_mode=True) - @activity.defn(name="agent_validatePrompt") - async def mock_validate(prompt: ValidationInput) -> ValidationResult: + @activity.defn(name="agent_validate_prompt") + async def mock_validate(prompt: ValidationInput, fallback_mode: bool) -> ValidationResult: return ValidationResult(validationResult=True, validationFailedReason={}) - @activity.defn(name="agent_toolPlanner") - async def mock_planner(input: ToolPromptInput) -> dict: + @activity.defn(name="agent_tool_planner") + async def mock_planner(input: ToolPromptInput, fallback_mode: bool) -> dict: return { "next": "confirm", "tool": "list_products", diff --git a/tests/test_tool_activities.py b/tests/test_tool_activities.py index 5c37068..a0f5127 100644 --- a/tests/test_tool_activities.py +++ b/tests/test_tool_activities.py @@ -28,27 +28,27 @@ def setup_method(self): self.tool_activities = ToolActivities() @pytest.mark.asyncio - async def test_agent_validatePrompt_valid_prompt( + async def test_agent_validate_prompt_valid_prompt( self, sample_agent_goal, sample_conversation_history ): - """Test agent_validatePrompt with a valid prompt.""" + """Test agent_validate_prompt with a valid prompt.""" validation_input = ValidationInput( prompt="I need help with the test tool", conversation_history=sample_conversation_history, agent_goal=sample_agent_goal, ) - # 
Mock the agent_toolPlanner to return a valid response + # Mock the agent_tool_planner to return a valid response mock_response = {"validationResult": True, "validationFailedReason": {}} with patch.object( - self.tool_activities, "agent_toolPlanner", new_callable=AsyncMock + self.tool_activities, "agent_tool_planner", new_callable=AsyncMock ) as mock_planner: mock_planner.return_value = mock_response activity_env = ActivityEnvironment() result = await activity_env.run( - self.tool_activities.agent_validatePrompt, validation_input + self.tool_activities.agent_validate_prompt, validation_input, False ) assert isinstance(result, ValidationResult) @@ -59,17 +59,17 @@ async def test_agent_validatePrompt_valid_prompt( mock_planner.assert_called_once() @pytest.mark.asyncio - async def test_agent_validatePrompt_invalid_prompt( + async def test_agent_validate_prompt_invalid_prompt( self, sample_agent_goal, sample_conversation_history ): - """Test agent_validatePrompt with an invalid prompt.""" + """Test agent_validate_prompt with an invalid prompt.""" validation_input = ValidationInput( prompt="asdfghjkl nonsense", conversation_history=sample_conversation_history, agent_goal=sample_agent_goal, ) - # Mock the agent_toolPlanner to return an invalid response + # Mock the agent_tool_planner to return an invalid response mock_response = { "validationResult": False, "validationFailedReason": { @@ -79,13 +79,13 @@ async def test_agent_validatePrompt_invalid_prompt( } with patch.object( - self.tool_activities, "agent_toolPlanner", new_callable=AsyncMock + self.tool_activities, "agent_tool_planner", new_callable=AsyncMock ) as mock_planner: mock_planner.return_value = mock_response activity_env = ActivityEnvironment() result = await activity_env.run( - self.tool_activities.agent_validatePrompt, validation_input + self.tool_activities.agent_validate_prompt, validation_input, False ) assert isinstance(result, ValidationResult) @@ -93,13 +93,13 @@ async def test_agent_validatePrompt_invalid_prompt( assert "doesn't make sense" in str(result.validationFailedReason) @pytest.mark.asyncio - async def test_agent_toolPlanner_success(self): - """Test agent_toolPlanner with successful LLM response.""" + async def test_agent_tool_planner_success(self): + """Test agent_tool_planner with successful LLM response.""" prompt_input = ToolPromptInput( prompt="Test prompt", context_instructions="Test context instructions" ) - # Mock the completion function + # Mock the llm_manager.call_llm method mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[ @@ -108,12 +108,14 @@ async def test_agent_toolPlanner_success(self): '{"next": "confirm", "tool": "TestTool", "response": "Test response"}' ) - with patch("activities.tool_activities.completion") as mock_completion: - mock_completion.return_value = mock_response + with patch.object( + self.tool_activities.llm_manager, "call_llm", new_callable=AsyncMock + ) as mock_call_llm: + mock_call_llm.return_value = mock_response activity_env = ActivityEnvironment() result = await activity_env.run( - self.tool_activities.agent_toolPlanner, prompt_input + self.tool_activities.agent_tool_planner, prompt_input, False ) assert isinstance(result, dict) @@ -121,17 +123,18 @@ async def test_agent_toolPlanner_success(self): assert result["tool"] == "TestTool" assert result["response"] == "Test response" - # Verify completion was called with correct parameters - mock_completion.assert_called_once() - call_args = mock_completion.call_args[1] - assert call_args["model"] 
== self.tool_activities.llm_model - assert len(call_args["messages"]) == 2 - assert call_args["messages"][0]["role"] == "system" - assert call_args["messages"][1]["role"] == "user" + # Verify call_llm was called with correct parameters + mock_call_llm.assert_called_once() + call_args = mock_call_llm.call_args[0][ + 0 + ] # First positional argument (messages) + assert len(call_args) == 2 + assert call_args[0]["role"] == "system" + assert call_args[1]["role"] == "user" @pytest.mark.asyncio - async def test_agent_toolPlanner_with_custom_base_url(self): - """Test agent_toolPlanner with custom base URL configuration.""" + async def test_agent_tool_planner_with_custom_base_url(self): + """Test agent_tool_planner with custom base URL configuration.""" # Set up tool activities with custom base URL with patch.dict(os.environ, {"LLM_BASE_URL": "https://custom.endpoint.com"}): tool_activities = ToolActivities() @@ -146,36 +149,38 @@ async def test_agent_toolPlanner_with_custom_base_url(self): 0 ].message.content = '{"next": "done", "response": "Test"}' - with patch("activities.tool_activities.completion") as mock_completion: - mock_completion.return_value = mock_response + with patch.object( + tool_activities.llm_manager, "call_llm", new_callable=AsyncMock + ) as mock_call_llm: + mock_call_llm.return_value = mock_response activity_env = ActivityEnvironment() - await activity_env.run(tool_activities.agent_toolPlanner, prompt_input) + await activity_env.run(tool_activities.agent_tool_planner, prompt_input, False) - # Verify base_url was included in the call - call_args = mock_completion.call_args[1] - assert "base_url" in call_args - assert call_args["base_url"] == "https://custom.endpoint.com" + # Verify call_llm was called + mock_call_llm.assert_called_once() @pytest.mark.asyncio - async def test_agent_toolPlanner_json_parsing_error(self): - """Test agent_toolPlanner handles JSON parsing errors.""" + async def test_agent_tool_planner_json_parsing_error(self): + """Test agent_tool_planner handles JSON parsing errors.""" prompt_input = ToolPromptInput( prompt="Test prompt", context_instructions="Test context instructions" ) - # Mock the completion function to return invalid JSON + # Mock the llm_manager.call_llm method to return invalid JSON mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = "Invalid JSON response" - with patch("activities.tool_activities.completion") as mock_completion: - mock_completion.return_value = mock_response + with patch.object( + self.tool_activities.llm_manager, "call_llm", new_callable=AsyncMock + ) as mock_call_llm: + mock_call_llm.return_value = mock_response activity_env = ActivityEnvironment() with pytest.raises(Exception): # Should raise JSON parsing error await activity_env.run( - self.tool_activities.agent_toolPlanner, prompt_input + self.tool_activities.agent_tool_planner, prompt_input, False ) @pytest.mark.asyncio @@ -331,7 +336,7 @@ def setup_method(self): self.tool_activities = ToolActivities() @pytest.mark.asyncio - async def test_agent_validatePrompt_with_empty_conversation_history( + async def test_agent_validate_prompt_with_empty_conversation_history( self, sample_agent_goal ): """Test validation with empty conversation history.""" @@ -344,13 +349,13 @@ async def test_agent_validatePrompt_with_empty_conversation_history( mock_response = {"validationResult": True, "validationFailedReason": {}} with patch.object( - self.tool_activities, "agent_toolPlanner", new_callable=AsyncMock + self.tool_activities, 
"agent_tool_planner", new_callable=AsyncMock ) as mock_planner: mock_planner.return_value = mock_response activity_env = ActivityEnvironment() result = await activity_env.run( - self.tool_activities.agent_validatePrompt, validation_input + self.tool_activities.agent_validate_prompt, validation_input, False ) assert isinstance(result, ValidationResult) @@ -358,7 +363,7 @@ async def test_agent_validatePrompt_with_empty_conversation_history( assert result.validationFailedReason == {} @pytest.mark.asyncio - async def test_agent_toolPlanner_with_long_prompt(self): + async def test_agent_tool_planner_with_long_prompt(self): """Test toolPlanner with very long prompt.""" long_prompt = "This is a very long prompt " * 100 tool_prompt_input = ToolPromptInput( @@ -372,10 +377,13 @@ async def test_agent_toolPlanner_with_long_prompt(self): 0 ].message.content = '{"next": "done", "response": "Processed long prompt"}' - with patch("activities.tool_activities.completion", return_value=mock_response): + with patch.object( + self.tool_activities.llm_manager, "call_llm", new_callable=AsyncMock + ) as mock_call_llm: + mock_call_llm.return_value = mock_response activity_env = ActivityEnvironment() result = await activity_env.run( - self.tool_activities.agent_toolPlanner, tool_prompt_input + self.tool_activities.agent_tool_planner, tool_prompt_input, False ) assert isinstance(result, dict) diff --git a/tests/workflowtests/agent_goal_workflow_test.py b/tests/workflowtests/agent_goal_workflow_test.py index 6d08c44..50b7ff8 100644 --- a/tests/workflowtests/agent_goal_workflow_test.py +++ b/tests/workflowtests/agent_goal_workflow_test.py @@ -46,14 +46,15 @@ async def test_flight_booking(client: Client): async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput: return EnvLookupOutput(show_confirm=True, multi_goal_mode=True) - @activity.defn(name="agent_validatePrompt") - async def mock_agent_validatePrompt( + @activity.defn(name="agent_validate_prompt") + async def mock_agent_validate_prompt( validation_input: ValidationInput, + fallback_mode: bool, ) -> ValidationResult: return ValidationResult(validationResult=True, validationFailedReason={}) - @activity.defn(name="agent_toolPlanner") - async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict: + @activity.defn(name="agent_tool_planner") + async def mock_agent_tool_planner(input: ToolPromptInput, fallback_mode: bool) -> dict: return {"next": "done", "response": "Test response from LLM"} @activity.defn(name="mcp_list_tools") @@ -82,8 +83,8 @@ async def mock_dynamic_tool_activity(args: Sequence[RawValue]) -> dict: workflows=[AgentGoalWorkflow], activities=[ mock_get_wf_env_vars, - mock_agent_validatePrompt, - mock_agent_toolPlanner, + mock_agent_validate_prompt, + mock_agent_tool_planner, mock_mcp_list_tools, mock_mcp_tool_activity, mock_dynamic_tool_activity, @@ -101,7 +102,7 @@ async def mock_dynamic_tool_activity(args: Sequence[RawValue]) -> dict: prompt = "Hello!" 
- # async with Worker(client, task_queue=task_queue_name, workflows=[AgentGoalWorkflow], activities=[ToolActivities.agent_validatePrompt, ToolActivities.agent_toolPlanner, dynamic_tool_activity]): + # async with Worker(client, task_queue=task_queue_name, workflows=[AgentGoalWorkflow], activities=[ToolActivities.agent_validate_prompt, ToolActivities.agent_tool_planner, dynamic_tool_activity]): # todo set goal categories for scenarios handle = await client.start_workflow( diff --git a/workflows/agent_goal_workflow.py b/workflows/agent_goal_workflow.py index 1f8b801..751b6db 100644 --- a/workflows/agent_goal_workflow.py +++ b/workflows/agent_goal_workflow.py @@ -4,13 +4,14 @@ from temporalio import workflow from temporalio.common import RetryPolicy +from temporalio.exceptions import ActivityError from models.data_types import ( ConversationHistory, EnvLookupInput, EnvLookupOutput, NextStep, - ValidationInput, + ValidationInput, ValidationResult, ) from models.tool_definitions import AgentGoal from workflows import workflow_helpers as helpers @@ -44,6 +45,7 @@ class AgentGoalWorkflow: """Workflow that manages tool execution with user confirmation and conversation history.""" def __init__(self) -> None: + self.fallback_mode = False # Fallback mode indicates the fallback LLM should be used. self.conversation_history: ConversationHistory = {"messages": []} self.prompt_queue: Deque[str] = deque() self.conversation_summary: Optional[str] = None @@ -118,23 +120,7 @@ async def run(self, combined_input: CombinedInput) -> str: # Validate user-provided prompts if self.is_user_prompt(prompt): - self.add_message("user", prompt) - - # Validate the prompt before proceeding - validation_input = ValidationInput( - prompt=prompt, - conversation_history=self.conversation_history, - agent_goal=self.goal, - ) - validation_result = await workflow.execute_activity_method( - ToolActivities.agent_validatePrompt, - args=[validation_input], - schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, - start_to_close_timeout=LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT, - retry_policy=RetryPolicy( - initial_interval=timedelta(seconds=5), backoff_coefficient=1 - ), - ) + validation_result = await self._validate_prompt(prompt) # If validation fails, provide that feedback to the user - i.e., "your words make no sense, puny human" end this iteration of processing if not validation_result.validationResult: @@ -160,15 +146,7 @@ async def run(self, combined_input: CombinedInput) -> str: ) # connect to LLM and execute to get next steps - tool_data = await workflow.execute_activity_method( - ToolActivities.agent_toolPlanner, - prompt_input, - schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, - start_to_close_timeout=LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT, - retry_policy=RetryPolicy( - initial_interval=timedelta(seconds=5), backoff_coefficient=1 - ), - ) + tool_data = await self._execute_prompt(prompt_input) tool_data["force_confirm"] = self.show_tool_args_confirmation self.tool_data = tool_data @@ -222,6 +200,82 @@ async def run(self, combined_input: CombinedInput) -> str: self.add_message, ) + async def _execute_prompt(self, prompt_input: ToolPromptInput) -> dict: + summary = "fallback" if self.fallback_mode else "" + try: + tool_data = await workflow.execute_activity_method( + ToolActivities.agent_tool_planner, + args=[prompt_input, self.fallback_mode], + schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, + start_to_close_timeout=LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT, + retry_policy=RetryPolicy( + 
initial_interval=timedelta(seconds=5), + backoff_coefficient=1, + maximum_attempts=2 + ), + summary=summary + ) + except ActivityError as ae: + workflow.logger.info( + f"Tool planner failed 2 times, switching to fallback mode" + ) + self.fallback_mode = True + tool_data = await workflow.execute_activity_method( + ToolActivities.agent_tool_planner, + args=[prompt_input, self.fallback_mode], + schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, + start_to_close_timeout=LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT, + retry_policy=RetryPolicy( + initial_interval=timedelta(seconds=5), + backoff_coefficient=1, + maximum_attempts=2 + ), + summary='fallback' + ) + return tool_data + + async def _validate_prompt(self, prompt: str) -> ValidationResult: + self.add_message("user", prompt) + + # Validate the prompt before proceeding + validation_input = ValidationInput( + prompt=prompt, + conversation_history=self.conversation_history, + agent_goal=self.goal, + ) + try: + summary = "fallback" if self.fallback_mode else "" + validation_result = await workflow.execute_activity_method( + ToolActivities.agent_validate_prompt, + args=[validation_input, self.fallback_mode], + schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, + start_to_close_timeout=LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT, + retry_policy=RetryPolicy( + initial_interval=timedelta(seconds=5), + backoff_coefficient=1, + maximum_attempts=2 + ), + summary=summary + ) + except ActivityError as ae: + workflow.logger.info( + f"Validate prompt failed 2 times, switching to fallback mode" + ) + self.fallback_mode = True + validation_result = await workflow.execute_activity_method( + ToolActivities.agent_validate_prompt, + args=[validation_input, self.fallback_mode], + schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, + start_to_close_timeout=LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT, + retry_policy=RetryPolicy( + initial_interval=timedelta(seconds=5), + backoff_coefficient=1, + maximum_attempts=2 + ), + summary="fallback" + ) + return validation_result + # Signal that comes from api/main.py via a post to /send-prompt @workflow.signal async def user_prompt(self, prompt: str) -> None: diff --git a/workflows/workflow_helpers.py b/workflows/workflow_helpers.py index fb066b9..6032820 100644 --- a/workflows/workflow_helpers.py +++ b/workflows/workflow_helpers.py @@ -158,7 +158,7 @@ async def continue_as_new_if_needed( prompt=summary_prompt, context_instructions=summary_context ) conversation_summary = await workflow.start_activity_method( - "ToolActivities.agent_toolPlanner", + "ToolActivities.agent_tool_planner", summary_input, schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, ) From 4b67bbbae9fe4bd1261b675ac40335462dd1ec12 Mon Sep 17 00:00:00 2001 From: Steve Wall Date: Mon, 6 Oct 2025 16:18:14 -0600 Subject: [PATCH 2/2] wip: Add fallback LLM if primary fails. --- docs/setup.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/setup.md b/docs/setup.md index 3270d05..2c7487f 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -142,7 +142,7 @@ For a complete list of supported models and providers, visit the [LiteLLM docume #### Fallback LLM Configuration -The system includes automatic fallback functionality to improve reliability when the primary LLM becomes unavailable. The LLM Manager provides transparent failover with automatic recovery detection. +The system includes automatic fallback functionality to improve reliability when the primary LLM becomes unavailable. 
Configure fallback LLM settings in your `.env` file:
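
As a minimal sketch, using the variable names exercised in `tests/test_llm_manager.py` (the model identifiers, keys, and URLs below are illustrative placeholders taken from the test fixtures, not recommendations):

```bash
# Primary LLM
LLM_MODEL=openai/gpt-4
LLM_KEY=test-key-123                      # replace with your real provider key
#LLM_BASE_URL=https://custom.api.com      # only needed for self-hosted or proxy endpoints
#LLM_TIMEOUT_SECONDS=30

# Fallback LLM, used once the primary LLM exhausts its retries
LLM_FALLBACK_MODEL=anthropic/claude-3-5-sonnet-20241022
LLM_FALLBACK_KEY=fallback-key-456
#LLM_FALLBACK_BASE_URL=https://fallback.api.com
#LLM_FALLBACK_TIMEOUT_SECONDS=20
```

Note that the fallback is opt-in: as the tests assert, `LLMManager.call_llm` raises an exception matching "No fallback model configured" when asked to run in fallback mode without `LLM_FALLBACK_MODEL` set, so the workflow's switch to `fallback_mode` only helps if the fallback block above is populated.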