diff --git a/openevolve/config.py b/openevolve/config.py
index b04dc7c7..30ad5436 100644
--- a/openevolve/config.py
+++ b/openevolve/config.py
@@ -36,6 +36,20 @@ class LLMConfig:
     retries: int = 3
     retry_delay: int = 5
+    def __post_init__(self):
+        """Set up API key from environment if not provided"""
+        if not self.api_key:
+            # Try to get API key from environment
+            if self.primary_model.startswith("claude-"):
+                self.api_key = os.environ.get("ANTHROPIC_API_KEY")
+            else:
+                self.api_key = os.environ.get("OPENAI_API_KEY")
+
+        # Set default API base based on model type
+        if self.api_base == "https://api.openai.com/v1":
+            if self.primary_model.startswith("claude-"):
+                self.api_base = "https://api.anthropic.com/v1"
+
 
 
 @dataclass
 class PromptConfig:
diff --git a/openevolve/llm/__init__.py b/openevolve/llm/__init__.py
index 26bbef56..fb95f6eb 100644
--- a/openevolve/llm/__init__.py
+++ b/openevolve/llm/__init__.py
@@ -5,5 +5,6 @@
 from openevolve.llm.base import LLMInterface
 from openevolve.llm.ensemble import LLMEnsemble
 from openevolve.llm.openai import OpenAILLM
+from openevolve.llm.anthropic import AnthropicLLM
 
-__all__ = ["LLMInterface", "OpenAILLM", "LLMEnsemble"]
+__all__ = ["LLMInterface", "OpenAILLM", "AnthropicLLM", "LLMEnsemble"]
diff --git a/openevolve/llm/anthropic.py b/openevolve/llm/anthropic.py
new file mode 100644
index 00000000..145a2e15
--- /dev/null
+++ b/openevolve/llm/anthropic.py
@@ -0,0 +1,96 @@
+"""
+Anthropic Claude API interface for LLMs
+"""
+
+import asyncio
+import logging
+from typing import Any, Dict, List, Optional
+
+import anthropic
+
+from openevolve.config import LLMConfig
+from openevolve.llm.base import LLMInterface
+
+logger = logging.getLogger(__name__)
+
+
+class AnthropicLLM(LLMInterface):
+    """LLM interface using Anthropic's Claude API"""
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        model: Optional[str] = None,
+    ):
+        self.config = config
+        self.model = model or config.primary_model
+
+        # Set up API client
+        self.client = anthropic.Anthropic(
+            api_key=config.api_key,
+            base_url=config.api_base,
+        )
+
+        logger.info(f"Initialized Anthropic LLM with model: {self.model}")
+
+    async def generate(self, prompt: str, **kwargs) -> str:
+        """Generate text from a prompt"""
+        return await self.generate_with_context(
+            system_message=self.config.system_message,
+            messages=[{"role": "user", "content": prompt}],
+            **kwargs,
+        )
+
+    async def generate_with_context(
+        self, system_message: str, messages: List[Dict[str, str]], **kwargs
+    ) -> str:
+        """Generate text using a system message and conversational context"""
+        # Prepare messages for Claude format
+        formatted_messages = []
+        for msg in messages:
+            formatted_messages.append({"role": msg["role"], "content": msg["content"]})
+
+        # Set up generation parameters
+        params = {
+            "model": self.model,
+            "system": system_message,
+            "messages": formatted_messages,
+            "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
+            "temperature": kwargs.get("temperature", self.config.temperature),
+            "top_p": kwargs.get("top_p", self.config.top_p),
+        }
+
+        # Attempt the API call with retries
+        retries = kwargs.get("retries", self.config.retries)
+        retry_delay = kwargs.get("retry_delay", self.config.retry_delay)
+        timeout = kwargs.get("timeout", self.config.timeout)
+
+        for attempt in range(retries + 1):
+            try:
+                response = await asyncio.wait_for(self._call_api(params), timeout=timeout)
+                return response
+            except asyncio.TimeoutError:
+                if attempt < retries:
+                    logger.warning(f"Timeout on attempt {attempt + 1}/{retries + 1}. Retrying...")
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.error(f"All {retries + 1} attempts failed with timeout")
+                    raise
+            except Exception as e:
+                if attempt < retries:
+                    logger.warning(
+                        f"Error on attempt {attempt + 1}/{retries + 1}: {str(e)}. Retrying..."
+                    )
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.error(f"All {retries + 1} attempts failed with error: {str(e)}")
+                    raise
+
+    async def _call_api(self, params: Dict[str, Any]) -> str:
+        """Make the actual API call"""
+        # Use asyncio to run the blocking API call in a thread pool
+        loop = asyncio.get_event_loop()
+        response = await loop.run_in_executor(None, lambda: self.client.messages.create(**params))
+
+        # Extract the response content
+        return response.content[0].text
diff --git a/openevolve/llm/ensemble.py b/openevolve/llm/ensemble.py
index 0c518cca..3907d2b3 100644
--- a/openevolve/llm/ensemble.py
+++ b/openevolve/llm/ensemble.py
@@ -10,10 +10,19 @@
 from openevolve.config import LLMConfig
 from openevolve.llm.base import LLMInterface
 from openevolve.llm.openai import OpenAILLM
+from openevolve.llm.anthropic import AnthropicLLM
 
 logger = logging.getLogger(__name__)
 
 
+def create_llm(config: LLMConfig, model: str) -> LLMInterface:
+    """Create an LLM instance based on the model name"""
+    if model.startswith("claude-") or model.startswith("anthropic/"):
+        return AnthropicLLM(config, model=model)
+    else:
+        return OpenAILLM(config, model=model)
+
+
 class LLMEnsemble:
     """Ensemble of LLMs for generating diverse code modifications"""
 
@@ -21,8 +30,8 @@ def __init__(self, config: LLMConfig):
         self.config = config
 
         # Initialize primary and secondary models
-        self.primary_model = OpenAILLM(config, model=config.primary_model)
-        self.secondary_model = OpenAILLM(config, model=config.secondary_model)
+        self.primary_model = create_llm(config, config.primary_model)
+        self.secondary_model = create_llm(config, config.secondary_model)
 
         # Model weights for sampling
         self._weights = [
diff --git a/pyproject.toml b/pyproject.toml
index 24ed10ba..5ac01802 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@ authors = [
 ]
 dependencies = [
     "openai>=1.0.0",
+    "anthropic>=0.8.0",
     "pyyaml>=6.0",
     "numpy>=1.22.0",
     "tqdm>=4.64.0",
diff --git a/tests/test_llm.py b/tests/test_llm.py
new file mode 100644
index 00000000..dbdb5126
--- /dev/null
+++ b/tests/test_llm.py
@@ -0,0 +1,145 @@
+"""
+Tests for LLM implementations
+"""
+
+import asyncio
+import unittest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from openevolve.config import LLMConfig
+from openevolve.llm.anthropic import AnthropicLLM
+from openevolve.llm.openai import OpenAILLM
+
+
+class TestLLMImplementations(unittest.IsolatedAsyncioTestCase):
+    """Tests for LLM implementations"""
+
+    def setUp(self):
+        """Set up test configuration"""
+        self.config = LLMConfig(
+            primary_model="test-model",
+            api_key="test-key",
+            api_base="https://test.api",
+        )
+
+    @patch("anthropic.Anthropic")
+    async def test_anthropic_llm_generate(self, mock_anthropic):
+        """Test Anthropic LLM generate method"""
+        # Set up mock response
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="Test response")]
+        mock_anthropic.return_value.messages.create.return_value = mock_response
+
+        # Create LLM instance
+        llm = AnthropicLLM(self.config)
+
+        # Test generate
+        response = await llm.generate("Test prompt")
+        self.assertEqual(response, "Test response")
+
+        # Verify API call
+        mock_anthropic.return_value.messages.create.assert_called_once()
+        call_args = mock_anthropic.return_value.messages.create.call_args[1]
+        self.assertEqual(call_args["model"], "test-model")
+        self.assertEqual(call_args["messages"][0]["role"], "user")
+        self.assertEqual(call_args["messages"][0]["content"], "Test prompt")
+
+    @patch("anthropic.Anthropic")
+    async def test_anthropic_llm_generate_with_context(self, mock_anthropic):
+        """Test Anthropic LLM generate_with_context method"""
+        # Set up mock response
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="Test response")]
+        mock_anthropic.return_value.messages.create.return_value = mock_response
+
+        # Create LLM instance
+        llm = AnthropicLLM(self.config)
+
+        # Test generate_with_context
+        messages = [
+            {"role": "user", "content": "Test message 1"},
+            {"role": "assistant", "content": "Test response 1"},
+            {"role": "user", "content": "Test message 2"},
+        ]
+        response = await llm.generate_with_context("Test system", messages)
+        self.assertEqual(response, "Test response")
+
+        # Verify API call
+        mock_anthropic.return_value.messages.create.assert_called_once()
+        call_args = mock_anthropic.return_value.messages.create.call_args[1]
+        self.assertEqual(call_args["model"], "test-model")
+        self.assertEqual(call_args["system"], "Test system")
+        self.assertEqual(len(call_args["messages"]), 3)
+        self.assertEqual(call_args["messages"][0]["role"], "user")
+        self.assertEqual(call_args["messages"][0]["content"], "Test message 1")
+
+    @patch("openai.OpenAI")
+    async def test_openai_llm_generate(self, mock_openai):
+        """Test OpenAI LLM generate method"""
+        # Set up mock response
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
+        mock_openai.return_value.chat.completions.create.return_value = mock_response
+
+        # Create LLM instance
+        llm = OpenAILLM(self.config)
+
+        # Test generate
+        response = await llm.generate("Test prompt")
+        self.assertEqual(response, "Test response")
+
+        # Verify API call
+        mock_openai.return_value.chat.completions.create.assert_called_once()
+        call_args = mock_openai.return_value.chat.completions.create.call_args[1]
+        self.assertEqual(call_args["model"], "test-model")
+        self.assertEqual(call_args["messages"][0]["role"], "user")
+        self.assertEqual(call_args["messages"][0]["content"], "Test prompt")
+
+    @patch("openai.OpenAI")
+    async def test_openai_llm_generate_with_context(self, mock_openai):
+        """Test OpenAI LLM generate_with_context method"""
+        # Set up mock response
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
+        mock_openai.return_value.chat.completions.create.return_value = mock_response
+
+        # Create LLM instance
+        llm = OpenAILLM(self.config)
+
+        # Test generate_with_context
+        messages = [
+            {"role": "user", "content": "Test message 1"},
+            {"role": "assistant", "content": "Test response 1"},
+            {"role": "user", "content": "Test message 2"},
+        ]
+        response = await llm.generate_with_context("Test system", messages)
+        self.assertEqual(response, "Test response")
+
+        # Verify API call
+        mock_openai.return_value.chat.completions.create.assert_called_once()
+        call_args = mock_openai.return_value.chat.completions.create.call_args[1]
+        self.assertEqual(call_args["model"], "test-model")
+        self.assertEqual(call_args["messages"][0]["role"], "system")
+        self.assertEqual(call_args["messages"][0]["content"], "Test system")
+        self.assertEqual(len(call_args["messages"]), 4)  # system + 3 messages
+
+    def test_llm_config_model_detection(self):
+        """Test LLM configuration model type detection"""
+        # Test OpenAI model
+        config = LLMConfig(primary_model="gpt-4")
+        self.assertEqual(config.api_base, "https://api.openai.com/v1")
+
+        # Test Claude model
+        config = LLMConfig(primary_model="claude-3-sonnet")
+        self.assertEqual(config.api_base, "https://api.anthropic.com/v1")
+
+        # Test custom API base
+        config = LLMConfig(
+            primary_model="claude-3-sonnet",
+            api_base="https://custom.api",
+        )
+        self.assertEqual(config.api_base, "https://custom.api")
+
+
+if __name__ == "__main__":
+    unittest.main()