diff --git a/openevolve/llm/openai.py b/openevolve/llm/openai.py
index 7946b4d8..50eacbad 100644
--- a/openevolve/llm/openai.py
+++ b/openevolve/llm/openai.py
@@ -1,8 +1,15 @@
 """
 OpenAI API interface for LLMs
 """
-
 import asyncio
+
+try:
+    asyncio.get_running_loop()
+except RuntimeError:
+
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+
 import logging
 import time
 from typing import Any, Dict, List, Optional, Union
@@ -14,6 +21,10 @@
 logger = logging.getLogger(__name__)
 
+_O_SERIES_MODELS = {"o1", "o1-mini", "o1-pro",
+                    "o3", "o3-mini", "o3-pro",
+                    "o4-mini"}
+
 
 class OpenAILLM(LLMInterface):
     """LLM interface using OpenAI-compatible APIs"""
 
@@ -64,22 +75,50 @@ async def generate_with_context(
         formatted_messages = [{"role": "system", "content": system_message}]
         formatted_messages.extend(messages)
 
+        # Base request parameters shared by all branches
+        params: Dict[str, Any] = {
+            "model": self.model,
+            "messages": formatted_messages,
+        }
+
         # Set up generation parameters
-        if self.api_base == "https://api.openai.com/v1" and str(self.model).lower().startswith("o"):
-            # For o-series models
-            params = {
-                "model": self.model,
-                "messages": formatted_messages,
-                "max_completion_tokens": kwargs.get("max_tokens", self.max_tokens),
-            }
+        # if self.api_base == "https://api.openai.com/v1" and str(self.model).lower().startswith("o"):
+        #     # For o-series models
+        #     params = {
+        #         "model": self.model,
+        #         "messages": formatted_messages,
+        #         "max_completion_tokens": kwargs.get("max_tokens", self.max_tokens),
+        #     }
+        # else:
+        #     params = {
+        #         "model": self.model,
+        #         "messages": formatted_messages,
+        #         "temperature": kwargs.get("temperature", self.temperature),
+        #         "top_p": kwargs.get("top_p", self.top_p),
+        #         "max_tokens": kwargs.get("max_tokens", self.max_tokens),
+        #     }
+
+        if self.api_base == "https://api.openai.com/v1":
+            params["max_completion_tokens"] = kwargs.get(
+                "max_tokens", self.max_tokens)
         else:
-            params = {
-                "model": self.model,
-                "messages": formatted_messages,
-                "temperature": kwargs.get("temperature", self.temperature),
-                "top_p": kwargs.get("top_p", self.top_p),
-                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
-            }
+            params["max_tokens"] = kwargs.get("max_tokens", self.max_tokens)
+
+        model_name = str(self.model).lower()
+        if self.api_base == "https://api.openai.com/v1" and model_name in _O_SERIES_MODELS:
+            # Warn if the user configured sampling parameters for an o-series model
+            if self.temperature is not None:
+                logger.warning(
+                    f"Model {self.model!r} does not support temperature/top_p; ignoring them"
+                )
+            kwargs.pop("temperature", None)
+            kwargs.pop("top_p", None)
+
+        else:
+            params["temperature"] = kwargs.get("temperature", self.temperature)
+            params["top_p"] = kwargs.get("top_p", self.top_p)
+
+        logger.debug("LLM params: %s", list(params.keys()))
 
         # Add seed parameter for reproducibility if configured
         # Skip seed parameter for Google AI Studio endpoint as it doesn't support it
@@ -104,10 +143,12 @@ async def generate_with_context(
                 return response
             except asyncio.TimeoutError:
                 if attempt < retries:
-                    logger.warning(f"Timeout on attempt {attempt + 1}/{retries + 1}. Retrying...")
+                    logger.warning(
+                        f"Timeout on attempt {attempt + 1}/{retries + 1}. Retrying...")
                     await asyncio.sleep(retry_delay)
                 else:
-                    logger.error(f"All {retries + 1} attempts failed with timeout")
+                    logger.error(
+                        f"All {retries + 1} attempts failed with timeout")
                     raise
             except Exception as e:
                 if attempt < retries:
@@ -116,7 +157,8 @@ async def generate_with_context(
                     )
                     await asyncio.sleep(retry_delay)
                 else:
-                    logger.error(f"All {retries + 1} attempts failed with error: {str(e)}")
+                    logger.error(
+                        f"All {retries + 1} attempts failed with error: {str(e)}")
                     raise
 
     async def _call_api(self, params: Dict[str, Any]) -> str:
diff --git a/tests/test_openai_model.py b/tests/test_openai_model.py
new file mode 100644
index 00000000..b3c7166a
--- /dev/null
+++ b/tests/test_openai_model.py
@@ -0,0 +1,75 @@
+
+"""
+Tests for o-series model config checks in OpenAILLM
+"""
+import asyncio
+import unittest
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+from openevolve.llm.openai import OpenAILLM
+
+
+class TestOpenAILLM(unittest.TestCase):
+
+    def setUp(self):
+        self.model_cfg = SimpleNamespace(
+            name="test-model",
+            system_message="SYS",
+            temperature=0.7,
+            top_p=0.98,
+            max_tokens=42,
+            timeout=1,
+            retries=0,
+            retry_delay=0,
+            api_base="https://api.openai.com/v1",
+            api_key="fake-key",
+            random_seed=123,
+        )
+
+        fake_choice = SimpleNamespace(message=SimpleNamespace(content=" OK"))
+        fake_response = SimpleNamespace(choices=[fake_choice])
+
+        self.fake_client = MagicMock()
+        self.fake_client.chat.completions.create.return_value = fake_response
+
+    def test_generate_happy_path(self):
+
+        with patch("openevolve.llm.openai.openai.OpenAI", return_value=self.fake_client):
+            llm = OpenAILLM(self.model_cfg)
+
+        result = asyncio.get_event_loop().run_until_complete(
+            llm.generate("hello world")
+        )
+
+        self.assertEqual(result, " OK")
+
+        called_kwargs = self.fake_client.chat.completions.create.call_args.kwargs
+        msgs = called_kwargs["messages"]
+        self.assertEqual(msgs[0]["role"], "system")
+        self.assertEqual(msgs[0]["content"], "SYS")
+        self.assertEqual(msgs[1]["role"], "user")
+        self.assertEqual(msgs[1]["content"], "hello world")
+
+    def test_max_completion_tokens_branch(self):
+        self.model_cfg.name = "o4-mini"
+        with patch("openevolve.llm.openai.openai.OpenAI", return_value=self.fake_client):
+            llm = OpenAILLM(self.model_cfg)
+            asyncio.get_event_loop().run_until_complete(llm.generate("foo"))
+
+        called = self.fake_client.chat.completions.create.call_args.kwargs
+
+        self.assertIn("max_completion_tokens", called)
+        self.assertNotIn("max_tokens", called)
+
+    def test_fallback_max_tokens_branch(self):
+
+        self.model_cfg.api_base = "https://my.custom.endpoint"
+        with patch("openevolve.llm.openai.openai.OpenAI", return_value=self.fake_client):
+            llm = OpenAILLM(self.model_cfg)
+            asyncio.get_event_loop().run_until_complete(llm.generate("bar"))
+
+        called = self.fake_client.chat.completions.create.call_args.kwargs
+
+        self.assertIn("max_tokens", called)
+        self.assertNotIn("max_completion_tokens", called)