diff --git a/env.example b/env.example index 4c8d355d5..08da27b8d 100644 --- a/env.example +++ b/env.example @@ -164,6 +164,15 @@ LLM_MODEL=gpt-4o LLM_BINDING_HOST=https://api.openai.com/v1 LLM_BINDING_API_KEY=your_api_key +##################################### +### Query-specific LLM Configuration +### Use a more powerful model for queries while keeping extraction economical +##################################### +QUERY_BINDING=openai +QUERY_MODEL=gpt-5 +QUERY_BINDING_HOST=https://api.openai.com/v1 +QUERY_BINDING_API_KEY=your_api_key + ### Optional for Azure # AZURE_OPENAI_API_VERSION=2024-08-01-preview # AZURE_OPENAI_DEPLOYMENT=gpt-4o diff --git a/lightrag/api/config.py b/lightrag/api/config.py index de569f472..450616f98 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -326,6 +326,17 @@ def parse_args() -> argparse.Namespace: # Inject model configuration args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest") + args.query_binding = get_env_value("QUERY_BINDING", args.llm_binding) + args.query_model = get_env_value("QUERY_MODEL", args.llm_model) + args.query_binding_host = get_env_value( + "QUERY_BINDING_HOST", + get_default_host(args.query_binding) + if args.query_binding != args.llm_binding + else args.llm_binding_host, + ) + args.query_binding_api_key = get_env_value( + "QUERY_BINDING_API_KEY", args.llm_binding_api_key + ) args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest") args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index e8fdb7007..42bcad179 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -325,6 +325,10 @@ async def lifespan(app: FastAPI): # Store background tasks app.state.background_tasks = set() + # Store query LLM function in app state + app.state.query_llm_func = query_llm_func + app.state.query_llm_kwargs = query_llm_kwargs + try: # Initialize database connections await rag.initialize_storages() @@ -497,6 +501,11 @@ async def optimized_azure_openai_model_complete( return optimized_azure_openai_model_complete + llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int) + embedding_timeout = get_env_value( + "EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT, int + ) + def create_llm_model_func(binding: str): """ Create LLM model function based on binding type. @@ -524,6 +533,76 @@ def create_llm_model_func(binding: str): except ImportError as e: raise Exception(f"Failed to import {binding} LLM binding: {e}") + def create_query_llm_func(args, config_cache, llm_timeout): + """ + Create query-specific LLM function if QUERY_BINDING or QUERY_MODEL is different. + Returns tuple of (query_llm_func, query_llm_kwargs). 
+ """ + # Check if query-specific LLM is configured + # Only skip if BOTH binding AND model are the same + if not hasattr(args, "query_binding") or ( + args.query_binding == args.llm_binding + and args.query_model == args.llm_model + ): + logger.info("Using same LLM for both extraction and query") + return None, {} + + logger.info( + f"Creating separate query LLM: {args.query_binding}/{args.query_model}" + ) + + # Create a temporary args object for query LLM + class QueryArgs: + pass + + query_args = QueryArgs() + query_args.llm_binding = args.query_binding + query_args.llm_model = args.query_model + query_args.llm_binding_host = args.query_binding_host + query_args.llm_binding_api_key = args.query_binding_api_key + + # Create query-specific LLM function based on binding type + query_llm_func = None + query_llm_kwargs = {} + + try: + if args.query_binding == "openai": + query_llm_func = create_optimized_openai_llm_func( + config_cache, query_args, llm_timeout + ) + elif args.query_binding == "azure_openai": + query_llm_func = create_optimized_azure_openai_llm_func( + config_cache, query_args, llm_timeout + ) + elif args.query_binding == "ollama": + from lightrag.llm.ollama import ollama_model_complete + + query_llm_func = ollama_model_complete + query_llm_kwargs = create_llm_model_kwargs( + args.query_binding, query_args, llm_timeout + ) + elif args.query_binding == "lollms": + from lightrag.llm.lollms import lollms_model_complete + + query_llm_func = lollms_model_complete + query_llm_kwargs = create_llm_model_kwargs( + args.query_binding, query_args, llm_timeout + ) + elif args.query_binding == "aws_bedrock": + query_llm_func = bedrock_model_complete + else: + raise ValueError(f"Unsupported query binding: {args.query_binding}") + + logger.info(f"Query LLM configured: {args.query_model}") + return query_llm_func, query_llm_kwargs + + except ImportError as e: + raise Exception(f"Failed to import {args.query_binding} LLM binding: {e}") + + query_llm_func, query_llm_kwargs = create_query_llm_func( + args, config_cache, llm_timeout + ) + def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict: """ Create LLM model kwargs based on binding type. 
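For context, a minimal standalone sketch (not part of the patch; the helper name resolve_query_llm_settings is purely illustrative) of the fallback behaviour that the parse_args() and create_query_llm_func() changes above implement: every QUERY_* setting defaults to its LLM_* counterpart, so a separate query LLM is only built when the binding or the model actually differs.

import os

def resolve_query_llm_settings(
    llm_binding: str, llm_model: str, llm_host: str, llm_api_key: str
) -> dict:
    # Each QUERY_* variable falls back to the corresponding LLM_* value.
    query_binding = os.getenv("QUERY_BINDING", llm_binding)
    query_model = os.getenv("QUERY_MODEL", llm_model)
    return {
        "binding": query_binding,
        "model": query_model,
        "host": os.getenv("QUERY_BINDING_HOST", llm_host),
        "api_key": os.getenv("QUERY_BINDING_API_KEY", llm_api_key),
        # Reuse the extraction LLM when both binding and model are unchanged.
        "separate_query_llm": not (
            query_binding == llm_binding and query_model == llm_model
        ),
    }

With only QUERY_MODEL set (for example gpt-5 while LLM_MODEL stays gpt-4o), the binding, host and API key are inherited from the extraction configuration and only the model name changes.
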
diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index f9353dda1..a951ec46b 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -683,6 +683,14 @@ async def stream_generator(): else: first_chunk_time = time.time_ns() + # Check if query-specific LLM is configured + if ( + hasattr(raw_request.app.state, "query_llm_func") + and raw_request.app.state.query_llm_func + ): + query_param.model_func = raw_request.app.state.query_llm_func + logger.debug("Using query-specific LLM for Ollama chat") + # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task match_result = re.search( r"\n\nUSER:", cleaned_query, re.MULTILINE diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 53cc41c00..c90baa9bf 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -6,7 +6,7 @@ import logging from typing import Any, Dict, List, Literal, Optional -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Request from lightrag.base import QueryParam from lightrag.api.utils_api import get_combined_auth_dependency from pydantic import BaseModel, Field, field_validator @@ -267,7 +267,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): }, }, ) - async def query_text(request: QueryRequest): + async def query_text(request: QueryRequest, req: Request): """ Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored. @@ -343,6 +343,13 @@ async def query_text(request: QueryRequest): # Force stream=False for /query endpoint regardless of include_references setting param.stream = False + # Use query-specific LLM if available + if ( + hasattr(req.app.state, "query_llm_func") + and req.app.state.query_llm_func + ): + param.model_func = req.app.state.query_llm_func + # Unified approach: always use aquery_llm for both cases result = await rag.aquery_llm(request.query, param=param) @@ -438,7 +445,7 @@ async def query_text(request: QueryRequest): }, }, ) - async def query_text_stream(request: QueryRequest): + async def query_text_stream(request: QueryRequest, req: Request): """ Advanced RAG query endpoint with flexible streaming response. @@ -560,6 +567,13 @@ async def query_text_stream(request: QueryRequest): stream_mode = request.stream if request.stream is not None else True param = request.to_query_params(stream_mode) + # Use query-specific LLM if available + if ( + hasattr(req.app.state, "query_llm_func") + and req.app.state.query_llm_func + ): + param.model_func = req.app.state.query_llm_func + from fastapi.responses import StreamingResponse # Unified approach: always use aquery_llm for all cases @@ -907,7 +921,7 @@ async def stream_generator(): }, }, ) - async def query_data(request: QueryRequest): + async def query_data(request: QueryRequest, req: Request): """ Advanced data retrieval endpoint for structured RAG analysis. 
@@ -1002,6 +1016,14 @@ async def query_data(request: QueryRequest): """ try: param = request.to_query_params(False) # No streaming for data endpoint + + # Use query-specific LLM if available (for keyword extraction) + if ( + hasattr(req.app.state, "query_llm_func") + and req.app.state.query_llm_func + ): + param.model_func = req.app.state.query_llm_func + response = await rag.aquery_data(request.query, param=param) # aquery_data returns the new format with status, message, data, and metadata
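
The route handlers above repeat the same guarded assignment. A minimal sketch of that pattern as a shared helper (the name apply_query_llm is hypothetical and not part of the patch):

from fastapi import Request
from lightrag.base import QueryParam


def apply_query_llm(param: QueryParam, req: Request) -> QueryParam:
    # If a query-specific LLM was configured at startup, attach it to the
    # QueryParam; model_func then takes precedence over the RAG instance's
    # default LLM for this request.
    query_llm_func = getattr(req.app.state, "query_llm_func", None)
    if query_llm_func:
        param.model_func = query_llm_func
    return param

Factoring the check out this way would keep /query, /query/stream and /query/data in sync if the app-state attribute name ever changes.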