14 changes: 14 additions & 0 deletions openevolve/config.py
@@ -36,6 +36,20 @@ class LLMConfig:
retries: int = 3
retry_delay: int = 5

def __post_init__(self):
"""Set up API key from environment if not provided"""
if not self.api_key:
# Try to get API key from environment
if self.primary_model.startswith("claude-"):
self.api_key = os.environ.get("ANTHROPIC_API_KEY")
else:
self.api_key = os.environ.get("OPENAI_API_KEY")

# Set default API base based on model type
if self.api_base == "https://api.openai.com/v1":
if self.primary_model.startswith("claude-"):
self.api_base = "https://api.anthropic.com/v1"


@dataclass
class PromptConfig:
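With the `__post_init__` hook above, callers can omit `api_key` entirely and let it resolve from the environment. A minimal sketch of the intended behaviour (it assumes `config.py` imports `os` and that `ANTHROPIC_API_KEY` is exported; the model name is illustrative):

```python
from openevolve.config import LLMConfig

# With ANTHROPIC_API_KEY exported and no api_key passed, the "claude-" prefix
# drives both the key lookup and the default API base rewrite.
config = LLMConfig(primary_model="claude-3-sonnet")
print(config.api_key)   # value of ANTHROPIC_API_KEY
print(config.api_base)  # "https://api.anthropic.com/v1"
```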
3 changes: 2 additions & 1 deletion openevolve/llm/__init__.py
@@ -5,5 +5,6 @@
from openevolve.llm.base import LLMInterface
from openevolve.llm.ensemble import LLMEnsemble
from openevolve.llm.openai import OpenAILLM
from openevolve.llm.anthropic import AnthropicLLM

__all__ = ["LLMInterface", "OpenAILLM", "LLMEnsemble"]
__all__ = ["LLMInterface", "OpenAILLM", "AnthropicLLM", "LLMEnsemble"]
96 changes: 96 additions & 0 deletions openevolve/llm/anthropic.py
@@ -0,0 +1,96 @@
"""
Anthropic Claude API interface for LLMs
"""

import asyncio
import logging
from typing import Any, Dict, List, Optional

import anthropic

from openevolve.config import LLMConfig
from openevolve.llm.base import LLMInterface

logger = logging.getLogger(__name__)


class AnthropicLLM(LLMInterface):
"""LLM interface using Anthropic's Claude API"""

def __init__(
self,
config: LLMConfig,
model: Optional[str] = None,
):
self.config = config
self.model = model or config.primary_model

# Set up API client
self.client = anthropic.Anthropic(
api_key=config.api_key,
base_url=config.api_base,
)

logger.info(f"Initialized Anthropic LLM with model: {self.model}")

async def generate(self, prompt: str, **kwargs) -> str:
"""Generate text from a prompt"""
return await self.generate_with_context(
system_message=self.config.system_message,
messages=[{"role": "user", "content": prompt}],
**kwargs,
)

async def generate_with_context(
self, system_message: str, messages: List[Dict[str, str]], **kwargs
) -> str:
"""Generate text using a system message and conversational context"""
# Prepare messages for Claude format
formatted_messages = []
for msg in messages:
formatted_messages.append({"role": msg["role"], "content": msg["content"]})

# Set up generation parameters
params = {
"model": self.model,
"system": system_message,
"messages": formatted_messages,
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
"temperature": kwargs.get("temperature", self.config.temperature),
"top_p": kwargs.get("top_p", self.config.top_p),
}

# Attempt the API call with retries
retries = kwargs.get("retries", self.config.retries)
retry_delay = kwargs.get("retry_delay", self.config.retry_delay)
timeout = kwargs.get("timeout", self.config.timeout)

for attempt in range(retries + 1):
try:
response = await asyncio.wait_for(self._call_api(params), timeout=timeout)
return response
except asyncio.TimeoutError:
if attempt < retries:
logger.warning(f"Timeout on attempt {attempt + 1}/{retries + 1}. Retrying...")
await asyncio.sleep(retry_delay)
else:
logger.error(f"All {retries + 1} attempts failed with timeout")
raise
except Exception as e:
if attempt < retries:
logger.warning(
f"Error on attempt {attempt + 1}/{retries + 1}: {str(e)}. Retrying..."
)
await asyncio.sleep(retry_delay)
else:
logger.error(f"All {retries + 1} attempts failed with error: {str(e)}")
raise

async def _call_api(self, params: Dict[str, Any]) -> str:
"""Make the actual API call"""
# Use asyncio to run the blocking API call in a thread pool
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, lambda: self.client.messages.create(**params))

# Extract the response content
return response.content[0].text
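A minimal usage sketch for `AnthropicLLM`, based only on the constructor and `generate` signature above (the `system_message`, token, and temperature defaults are assumed to live on `LLMConfig`; the key and prompt are placeholders):

```python
import asyncio

from openevolve.config import LLMConfig
from openevolve.llm.anthropic import AnthropicLLM

config = LLMConfig(primary_model="claude-3-sonnet", api_key="sk-ant-...")  # placeholder key
llm = AnthropicLLM(config)

async def main() -> None:
    # generate() wraps the prompt as a single user message and forwards
    # config.system_message as the Claude system prompt
    reply = await llm.generate("Explain the change in one sentence.")
    print(reply)

asyncio.run(main())
```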
13 changes: 11 additions & 2 deletions openevolve/llm/ensemble.py
@@ -10,19 +10,28 @@
from openevolve.config import LLMConfig
from openevolve.llm.base import LLMInterface
from openevolve.llm.openai import OpenAILLM
from openevolve.llm.anthropic import AnthropicLLM

logger = logging.getLogger(__name__)


def create_llm(config: LLMConfig, model: str) -> LLMInterface:
"""Create an LLM instance based on the model name"""
if model.startswith("claude-") or model.startswith("anthropic/"):
return AnthropicLLM(config, model=model)
else:
return OpenAILLM(config, model=model)


class LLMEnsemble:
"""Ensemble of LLMs for generating diverse code modifications"""

def __init__(self, config: LLMConfig):
self.config = config

# Initialize primary and secondary models
self.primary_model = OpenAILLM(config, model=config.primary_model)
self.secondary_model = OpenAILLM(config, model=config.secondary_model)
self.primary_model = create_llm(config, config.primary_model)
self.secondary_model = create_llm(config, config.secondary_model)

# Model weights for sampling
self._weights = [
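Because `create_llm` dispatches purely on the model-name prefix, an ensemble can mix providers without extra configuration. A small routing sketch (model names and key are illustrative; `secondary_model` is assumed to be an `LLMConfig` field, as used in `__init__` above):

```python
from openevolve.config import LLMConfig
from openevolve.llm.anthropic import AnthropicLLM
from openevolve.llm.ensemble import create_llm
from openevolve.llm.openai import OpenAILLM

config = LLMConfig(
    primary_model="claude-3-sonnet",
    secondary_model="gpt-4",
    api_key="test-key",  # placeholder
)

# "claude-" and "anthropic/" prefixes route to AnthropicLLM, everything else to OpenAILLM
assert isinstance(create_llm(config, "claude-3-sonnet"), AnthropicLLM)
assert isinstance(create_llm(config, "anthropic/claude-3-sonnet"), AnthropicLLM)
assert isinstance(create_llm(config, "gpt-4"), OpenAILLM)
```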
1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ authors = [
]
dependencies = [
"openai>=1.0.0",
"anthropic>=0.8.0",
"pyyaml>=6.0",
"numpy>=1.22.0",
"tqdm>=4.64.0",
145 changes: 145 additions & 0 deletions tests/test_llm.py
@@ -0,0 +1,145 @@
"""
Tests for LLM implementations
"""

import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

from openevolve.config import LLMConfig
from openevolve.llm.anthropic import AnthropicLLM
from openevolve.llm.openai import OpenAILLM


class TestLLMImplementations(unittest.IsolatedAsyncioTestCase):
"""Tests for LLM implementations"""

def setUp(self):
"""Set up test configuration"""
self.config = LLMConfig(
primary_model="test-model",
api_key="test-key",
api_base="https://test.api",
)

@patch("anthropic.Anthropic")
async def test_anthropic_llm_generate(self, mock_anthropic):
"""Test Anthropic LLM generate method"""
# Set up mock response
mock_response = MagicMock()
mock_response.content = [MagicMock(text="Test response")]
mock_anthropic.return_value.messages.create.return_value = mock_response

# Create LLM instance
llm = AnthropicLLM(self.config)

# Test generate
response = await llm.generate("Test prompt")
self.assertEqual(response, "Test response")

# Verify API call
mock_anthropic.return_value.messages.create.assert_called_once()
call_args = mock_anthropic.return_value.messages.create.call_args[1]
self.assertEqual(call_args["model"], "test-model")
self.assertEqual(call_args["messages"][0]["role"], "user")
self.assertEqual(call_args["messages"][0]["content"], "Test prompt")

@patch("anthropic.Anthropic")
async def test_anthropic_llm_generate_with_context(self, mock_anthropic):
"""Test Anthropic LLM generate_with_context method"""
# Set up mock response
mock_response = MagicMock()
mock_response.content = [MagicMock(text="Test response")]
mock_anthropic.return_value.messages.create.return_value = mock_response

# Create LLM instance
llm = AnthropicLLM(self.config)

# Test generate_with_context
messages = [
{"role": "user", "content": "Test message 1"},
{"role": "assistant", "content": "Test response 1"},
{"role": "user", "content": "Test message 2"},
]
response = await llm.generate_with_context("Test system", messages)
self.assertEqual(response, "Test response")

# Verify API call
mock_anthropic.return_value.messages.create.assert_called_once()
call_args = mock_anthropic.return_value.messages.create.call_args[1]
self.assertEqual(call_args["model"], "test-model")
self.assertEqual(call_args["system"], "Test system")
self.assertEqual(len(call_args["messages"]), 3)
self.assertEqual(call_args["messages"][0]["role"], "user")
self.assertEqual(call_args["messages"][0]["content"], "Test message 1")

@patch("openai.OpenAI")
async def test_openai_llm_generate(self, mock_openai):
"""Test OpenAI LLM generate method"""
# Set up mock response
mock_response = MagicMock()
mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
mock_openai.return_value.chat.completions.create.return_value = mock_response

# Create LLM instance
llm = OpenAILLM(self.config)

# Test generate
response = await llm.generate("Test prompt")
self.assertEqual(response, "Test response")

# Verify API call
mock_openai.return_value.chat.completions.create.assert_called_once()
call_args = mock_openai.return_value.chat.completions.create.call_args[1]
self.assertEqual(call_args["model"], "test-model")
self.assertEqual(call_args["messages"][0]["role"], "user")
self.assertEqual(call_args["messages"][0]["content"], "Test prompt")

@patch("openai.OpenAI")
async def test_openai_llm_generate_with_context(self, mock_openai):
"""Test OpenAI LLM generate_with_context method"""
# Set up mock response
mock_response = MagicMock()
mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
mock_openai.return_value.chat.completions.create.return_value = mock_response

# Create LLM instance
llm = OpenAILLM(self.config)

# Test generate_with_context
messages = [
{"role": "user", "content": "Test message 1"},
{"role": "assistant", "content": "Test response 1"},
{"role": "user", "content": "Test message 2"},
]
response = await llm.generate_with_context("Test system", messages)
self.assertEqual(response, "Test response")

# Verify API call
mock_openai.return_value.chat.completions.create.assert_called_once()
call_args = mock_openai.return_value.chat.completions.create.call_args[1]
self.assertEqual(call_args["model"], "test-model")
self.assertEqual(call_args["messages"][0]["role"], "system")
self.assertEqual(call_args["messages"][0]["content"], "Test system")
self.assertEqual(len(call_args["messages"]), 4) # system + 3 messages

def test_llm_config_model_detection(self):
"""Test LLM configuration model type detection"""
# Test OpenAI model
config = LLMConfig(primary_model="gpt-4")
self.assertEqual(config.api_base, "https://api.openai.com/v1")

# Test Claude model
config = LLMConfig(primary_model="claude-3-sonnet")
self.assertEqual(config.api_base, "https://api.anthropic.com/v1")

# Test custom API base
config = LLMConfig(
primary_model="claude-3-sonnet",
api_base="https://custom.api",
)
self.assertEqual(config.api_base, "https://custom.api")


if __name__ == "__main__":
unittest.main()
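These coroutine tests need an asyncio-aware runner; with `unittest.IsolatedAsyncioTestCase` (Python 3.8+) they can be run directly via `python -m unittest tests.test_llm -v` from the repository root.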