diff --git a/src/utils/utils.py b/src/utils/utils.py
index 6fae05bb..7cb59f9f 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -11,6 +11,7 @@
 from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 import gradio as gr
+from langchain_core.rate_limiters import InMemoryRateLimiter
 
 from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
 
@@ -31,6 +32,15 @@ def get_llm_model(provider: str, **kwargs):
     :param kwargs:
     :return:
     """
+    rate_limit_rps = kwargs.get("rate_limit_rps", 1.0)
+    rate_limit_bucket = kwargs.get("rate_limit_bucket", 10)
+    # Create rate limiter
+    rate_limiter = InMemoryRateLimiter(
+        requests_per_second=rate_limit_rps,
+        check_every_n_seconds=0.1,
+        max_bucket_size=rate_limit_bucket
+    )
+
     if provider not in ["ollama"]:
         env_var = f"{provider.upper()}_API_KEY"
         api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
@@ -49,6 +59,7 @@ def get_llm_model(provider: str, **kwargs):
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
             api_key=api_key,
+            rate_limiter=rate_limiter,
         )
     elif provider == 'mistral':
         if not kwargs.get("base_url", ""):
@@ -65,6 +76,7 @@ def get_llm_model(provider: str, **kwargs):
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
             api_key=api_key,
+            rate_limiter=rate_limiter,
         )
     elif provider == "openai":
         if not kwargs.get("base_url", ""):
@@ -77,6 +89,7 @@ def get_llm_model(provider: str, **kwargs):
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
             api_key=api_key,
+            rate_limiter=rate_limiter,
         )
     elif provider == "deepseek":
         if not kwargs.get("base_url", ""):
@@ -90,19 +103,24 @@ def get_llm_model(provider: str, **kwargs):
                 temperature=kwargs.get("temperature", 0.0),
                 base_url=base_url,
                 api_key=api_key,
+                rate_limiter=rate_limiter,
             )
         else:
+
             return ChatOpenAI(
                 model=kwargs.get("model_name", "deepseek-chat"),
                 temperature=kwargs.get("temperature", 0.0),
                 base_url=base_url,
                 api_key=api_key,
+                rate_limiter=rate_limiter,
             )
     elif provider == "google":
+
         return ChatGoogleGenerativeAI(
             model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
             temperature=kwargs.get("temperature", 0.0),
             google_api_key=api_key,
+            rate_limiter=rate_limiter,
         )
     elif provider == "ollama":
         if not kwargs.get("base_url", ""):
@@ -111,19 +129,23 @@ def get_llm_model(provider: str, **kwargs):
             base_url = kwargs.get("base_url")
 
         if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
+
             return DeepSeekR1ChatOllama(
                 model=kwargs.get("model_name", "deepseek-r1:14b"),
                 temperature=kwargs.get("temperature", 0.0),
                 num_ctx=kwargs.get("num_ctx", 32000),
                 base_url=base_url,
+                rate_limiter=rate_limiter,
             )
         else:
+
             return ChatOllama(
                 model=kwargs.get("model_name", "qwen2.5:7b"),
                 temperature=kwargs.get("temperature", 0.0),
                 num_ctx=kwargs.get("num_ctx", 32000),
                 num_predict=kwargs.get("num_predict", 1024),
                 base_url=base_url,
+                rate_limiter=rate_limiter,
             )
     elif provider == "azure_openai":
         if not kwargs.get("base_url", ""):
@@ -137,6 +159,7 @@ def get_llm_model(provider: str, **kwargs):
             api_version=api_version,
             azure_endpoint=base_url,
             api_key=api_key,
+            rate_limiter=rate_limiter,
         )
     elif provider == "alibaba":
         if not kwargs.get("base_url", ""):
diff --git a/webui.py b/webui.py
index f61da5bc..db8ce940 100644
--- a/webui.py
+++ b/webui.py
@@ -119,7 +119,9 @@ async def run_browser_agent(
         max_steps,
         use_vision,
         max_actions_per_step,
-        tool_calling_method
+        tool_calling_method,
+        rate_limit_rps,
+        rate_limit_bucket
 ):
     global _global_agent_state
     _global_agent_state.clear_stop()  # Clear any previous stop requests
@@ -149,6 +151,8 @@ async def run_browser_agent(
         temperature=llm_temperature,
         base_url=llm_base_url,
         api_key=llm_api_key,
+        rate_limit_rps=rate_limit_rps,
+        rate_limit_bucket=rate_limit_bucket
     )
     if agent_type == "org":
         final_result, errors, model_actions, model_thoughts, trace_file, history_file = await run_org_agent(
@@ -456,7 +460,9 @@ async def run_with_stream(
         max_steps,
         use_vision,
         max_actions_per_step,
-        tool_calling_method
+        tool_calling_method,
+        rate_limit_rps,
+        rate_limit_bucket
 ):
     global _global_agent_state
     stream_vw = 80
@@ -485,7 +491,9 @@ async def run_with_stream(
             max_steps=max_steps,
             use_vision=use_vision,
             max_actions_per_step=max_actions_per_step,
-            tool_calling_method=tool_calling_method
+            tool_calling_method=tool_calling_method,
+            rate_limit_rps=rate_limit_rps,
+            rate_limit_bucket=rate_limit_bucket
         )
         # Add HTML content at the start of the result array
         html_content = f"<h1 style='width:{stream_vw}vw; height:{stream_vh}vh'>Using browser...</h1>"
@@ -518,7 +526,9 @@ async def run_with_stream(
                 max_steps=max_steps,
                 use_vision=use_vision,
                 max_actions_per_step=max_actions_per_step,
-                tool_calling_method=tool_calling_method
+                tool_calling_method=tool_calling_method,
+                rate_limit_rps=rate_limit_rps,
+                rate_limit_bucket=rate_limit_bucket
             )
         )
 
@@ -632,7 +642,7 @@ async def close_global_browser():
         await _global_browser.close()
         _global_browser = None
 
-async def run_deep_search(research_task, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless):
+async def run_deep_search(research_task, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless, rate_limit_rps, rate_limit_bucket):
     from src.utils.deep_research import deep_research
     global _global_agent_state
 
@@ -646,6 +656,8 @@ async def run_deep_search(research_task, max_query_per_p
         temperature=llm_temperature,
         base_url=llm_base_url,
         api_key=llm_api_key,
+        rate_limit_rps=rate_limit_rps,
+        rate_limit_bucket=rate_limit_bucket
     )
     markdown_content, file_path = await deep_research(research_task, llm, _global_agent_state,
                                                       max_search_iterations=max_search_iteration_input,
@@ -775,6 +787,19 @@ def create_ui(config, theme_name="Ocean"):
                         value=config['llm_api_key'],
                         info="Your API key (leave blank to use .env)"
                     )
+                with gr.Row():
+                    rate_limit_rps = gr.Number(
+                        label="Requests/sec",
+                        value=config.get('rate_limit_rps', 1),
+                        precision=1,
+                        info="Max requests per second"
+                    )
+                    rate_limit_bucket = gr.Number(
+                        label="Max Bucket Size",
+                        value=config.get('rate_limit_bucket', 10),
+                        precision=0,
+                        info="Maximum burst capacity"
+                    )
 
                 # Change event to update context length slider
                 def update_llm_num_ctx_visibility(llm_provider):
@@ -932,7 +957,8 @@ def update_llm_num_ctx_visibility(llm_provider):
                 agent_type, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key,
                 use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h,
                 save_recording_path, save_agent_history_path, save_trace_path,  # Include the new path
-                enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_calling_method
+                enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_calling_method,
+                rate_limit_rps, rate_limit_bucket
             ],
             outputs=[
                 browser_view,  # Browser view
@@ -951,7 +977,7 @@ def update_llm_num_ctx_visibility(llm_provider):
         # Run Deep Research
         research_button.click(
             fn=run_deep_search,
-            inputs=[research_task_input, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless],
+            inputs=[research_task_input, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless, rate_limit_rps, rate_limit_bucket],
             outputs=[markdown_output_display, markdown_download, stop_research_button, research_button]
         )
         # Bind the stop button click event after errors_output is defined