diff --git a/src/utils/utils.py b/src/utils/utils.py
index 6fae05bb..7cb59f9f 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -11,6 +11,7 @@
 from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 import gradio as gr
+from langchain_core.rate_limiters import InMemoryRateLimiter
 from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
 
 
@@ -31,6 +32,15 @@ def get_llm_model(provider: str, **kwargs):
     :param kwargs:
     :return:
     """
+    rate_limit_rps = kwargs.get("rate_limit_rps", 1.0)
+    rate_limit_bucket = kwargs.get("rate_limit_bucket", 10)
+    # Create rate limiter
+    rate_limiter = InMemoryRateLimiter(
+        requests_per_second=rate_limit_rps,
+        check_every_n_seconds=0.1,
+        max_bucket_size=rate_limit_bucket
+    )
+
     if provider not in ["ollama"]:
         env_var = f"{provider.upper()}_API_KEY"
         api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
@@ -49,6 +59,7 @@ def get_llm_model(provider: str, **kwargs):
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
             api_key=api_key,
+            rate_limiter=rate_limiter,
         )
     elif provider == 'mistral':
         if not kwargs.get("base_url", ""):
@@ -65,6 +76,7 @@ def get_llm_model(provider: str, **kwargs):
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
             api_key=api_key,
+            rate_limiter=rate_limiter,
         )
     elif provider == "openai":
         if not kwargs.get("base_url", ""):
@@ -77,6 +89,7 @@ def get_llm_model(provider: str, **kwargs):
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
             api_key=api_key,
+            rate_limiter=rate_limiter,
         )
     elif provider == "deepseek":
         if not kwargs.get("base_url", ""):
@@ -90,19 +103,24 @@ def get_llm_model(provider: str, **kwargs):
                 temperature=kwargs.get("temperature", 0.0),
                 base_url=base_url,
                 api_key=api_key,
+                rate_limiter=rate_limiter,
             )
         else:
+
             return ChatOpenAI(
                 model=kwargs.get("model_name", "deepseek-chat"),
                 temperature=kwargs.get("temperature", 0.0),
                 base_url=base_url,
                 api_key=api_key,
+                rate_limiter=rate_limiter,
             )
     elif provider == "google":
+
         return ChatGoogleGenerativeAI(
             model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
             temperature=kwargs.get("temperature", 0.0),
             google_api_key=api_key,
+            rate_limiter=rate_limiter,
         )
     elif provider == "ollama":
         if not kwargs.get("base_url", ""):
@@ -111,19 +129,23 @@ def get_llm_model(provider: str, **kwargs):
             base_url = kwargs.get("base_url")
 
         if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
+
             return DeepSeekR1ChatOllama(
                 model=kwargs.get("model_name", "deepseek-r1:14b"),
                 temperature=kwargs.get("temperature", 0.0),
                 num_ctx=kwargs.get("num_ctx", 32000),
                 base_url=base_url,
+                rate_limiter=rate_limiter,
             )
         else:
+
             return ChatOllama(
                 model=kwargs.get("model_name", "qwen2.5:7b"),
                 temperature=kwargs.get("temperature", 0.0),
                 num_ctx=kwargs.get("num_ctx", 32000),
                 num_predict=kwargs.get("num_predict", 1024),
                 base_url=base_url,
+                rate_limiter=rate_limiter,
             )
     elif provider == "azure_openai":
         if not kwargs.get("base_url", ""):
@@ -137,6 +159,7 @@ def get_llm_model(provider: str, **kwargs):
             api_version=api_version,
             azure_endpoint=base_url,
             api_key=api_key,
+            rate_limiter=rate_limiter,
         )
     elif provider == "alibaba":
         if not kwargs.get("base_url", ""):
diff --git a/webui.py b/webui.py
index f61da5bc..db8ce940 100644
--- a/webui.py
+++ b/webui.py
@@ -119,7 +119,9 @@ async def run_browser_agent(
         max_steps,
         use_vision,
         max_actions_per_step,
-        tool_calling_method
+        tool_calling_method,
+        rate_limit_rps,
+        rate_limit_bucket
 ):
     global _global_agent_state
     _global_agent_state.clear_stop()  # Clear any previous stop requests
@@ -149,6 +151,8 @@ async def run_browser_agent(
         temperature=llm_temperature,
         base_url=llm_base_url,
         api_key=llm_api_key,
+        rate_limit_rps=rate_limit_rps,
+        rate_limit_bucket=rate_limit_bucket
     )
     if agent_type == "org":
         final_result, errors, model_actions, model_thoughts, trace_file, history_file = await run_org_agent(
@@ -456,7 +460,9 @@ async def run_with_stream(
         max_steps,
         use_vision,
         max_actions_per_step,
-        tool_calling_method
+        tool_calling_method,
+        rate_limit_rps,
+        rate_limit_bucket
 ):
     global _global_agent_state
     stream_vw = 80
@@ -485,7 +491,9 @@ async def run_with_stream(
             max_steps=max_steps,
             use_vision=use_vision,
             max_actions_per_step=max_actions_per_step,
-            tool_calling_method=tool_calling_method
+            tool_calling_method=tool_calling_method,
+            rate_limit_rps=rate_limit_rps,
+            rate_limit_bucket=rate_limit_bucket
         )
         # Add HTML content at the start of the result array
         html_content = f"
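
Reviewer note (not part of the patch): the sketch below illustrates the token-bucket behavior the new `InMemoryRateLimiter` wiring relies on, and how callers would tune it through the added `rate_limit_rps` / `rate_limit_bucket` kwargs. The constructor arguments mirror the defaults added in `utils.py`; the commented `get_llm_model` call uses hypothetical values and assumes this patch is applied.

```python
import time

from langchain_core.rate_limiters import InMemoryRateLimiter

# Same defaults as the new code in get_llm_model: 1 request/second,
# polling every 100 ms, with a burst bucket of at most 10 tokens.
rate_limiter = InMemoryRateLimiter(
    requests_per_second=1.0,    # long-run average request rate
    check_every_n_seconds=0.1,  # how often to check for a free token
    max_bucket_size=10,         # cap on burst size
)

start = time.monotonic()
for i in range(3):
    rate_limiter.acquire()  # blocks until a token is available
    print(f"request {i} released at {time.monotonic() - start:.1f}s")

# With the patch applied, the limiter can be tuned per model, e.g.:
#   llm = get_llm_model(
#       provider="openai",
#       model_name="gpt-4o",   # hypothetical model choice
#       rate_limit_rps=0.5,    # one request every two seconds
#       rate_limit_bucket=5,
#   )
```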